[GH-PAGES] Updated website
@@ -23,7 +23,7 @@ You will specifically learn about:
|
|||||||
# yourself with Triton, in a way that is easy to customize and extend.
|
# yourself with Triton, in a way that is easy to customize and extend.
|
||||||
#
|
#
|
||||||
# Roughly speaking, the kernel that we will write will implement the following blocked
|
# Roughly speaking, the kernel that we will write will implement the following blocked
|
||||||
# algorithm to multiply a (MxK) by a (KxN) matrix:
|
# algorithm to multiply a (M, K) by a (K, N) matrix:
|
||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
@@ -38,7 +38,7 @@ You will specifically learn about:
|
|||||||
# acc += dot(a, b)
|
# acc += dot(a, b)
|
||||||
# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
||||||
#
|
#
|
||||||
# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Compute Kernel
|
# Compute Kernel
|
||||||
@@ -53,35 +53,31 @@ You will specifically learn about:
|
|||||||
# ~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~
|
||||||
#
|
#
|
||||||
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
|
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
|
||||||
# y :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
# y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
|
||||||
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
||||||
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
||||||
# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
||||||
#
|
#
|
||||||
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
# pid_m = triton.program_id(0)
|
# offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
# pid_n = triton.program_id(1)
|
# offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
# rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
|
# offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
# rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
|
# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||||
# rk = triton.arange(0, BLOCK_SIZE_K)
|
# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||||
# // pointer for A operand
|
|
||||||
# pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
|
||||||
# // pointer for B operand
|
|
||||||
# pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
|
||||||
#
|
#
|
||||||
# And then updated in the inner loop as follows:
|
# And then updated in the inner loop as follows:
|
||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
# pa += BLOCK_SIZE_K * stride_a_1;
|
# pa += BLOCK_SIZE_K * stride_ak;
|
||||||
# pb += BLOCK_SIZE_K * stride_b_0;
|
# pb += BLOCK_SIZE_K * stride_bk;
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# L2 Cache Optimizations
|
# L2 Cache Optimizations
|
||||||
@@ -109,13 +105,25 @@ You will specifically learn about:
|
|||||||
#
|
#
|
||||||
# .. code-block:: python
|
# .. code-block:: python
|
||||||
#
|
#
|
||||||
# pid = triton.program_id(0);
|
# # program ID
|
||||||
# width = GROUP_M * grid_n;
|
# pid = tl.program_id(axis=0)
|
||||||
# group_id = pid // width;
|
# # number of program ids along the M axis
|
||||||
# # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
|
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
# group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
# # number of programs ids along the N axis
|
||||||
# pid_m = group_id * GROUP_M + (pid % group_size);
|
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
# pid_n = (pid % width) // (group_size);
|
# # number of programs in group
|
||||||
|
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
# # id of the group this program is in
|
||||||
|
# group_id = pid // num_pid_in_group
|
||||||
|
# # row-id of the first program in the group
|
||||||
|
# first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
# # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
|
||||||
|
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
# # *within groups*, programs are ordered in a column-major order
|
||||||
|
# # row-id of the program in the *launch grid*
|
||||||
|
# pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
# # col-id of the program in the *launch grid*
|
||||||
|
# pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
#
|
#
|
||||||
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||||
# we can see that if we compute the output in row-major ordering, we need to load 90
|
# we can see that if we compute the output in row-major ordering, we need to load 90
|
||||||
@@ -164,26 +172,19 @@ import triton.language as tl
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def matmul_kernel(
|
def matmul_kernel(
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
a_ptr,
|
a_ptr, b_ptr, c_ptr,
|
||||||
b_ptr,
|
|
||||||
c_ptr,
|
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
M,
|
M, N, K,
|
||||||
N,
|
|
||||||
K,
|
|
||||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||||
# by to get the element one row down (A has M rows)
|
# by to get the element one row down (A has M rows)
|
||||||
stride_am,
|
stride_am, stride_ak,
|
||||||
stride_ak,
|
stride_bk, stride_bn,
|
||||||
stride_bk,
|
stride_cm, stride_cn,
|
||||||
stride_bn,
|
# Meta-parameters
|
||||||
stride_cm,
|
|
||||||
stride_cn,
|
|
||||||
**meta,
|
**meta,
|
||||||
):
|
):
|
||||||
"""Kernel for computing the matmul AB = C
|
"""Kernel for computing the matmul C = A x B.
|
||||||
|
|
||||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||||
"""
|
"""
|
||||||
# extract meta-parameters
|
# extract meta-parameters
|
||||||
@@ -191,67 +192,65 @@ def matmul_kernel(
|
|||||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||||
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
||||||
GROUP_SIZE_M = 8
|
GROUP_SIZE_M = 8
|
||||||
|
|
||||||
|
# -----------------------------------------------------------
|
||||||
|
# Map program ids `pid` to the block of C it should compute.
|
||||||
|
# This is done in a grouped ordering to promote L2 data reuse
|
||||||
|
# See above `L2 Cache Optimizations` section for details
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
|
|
||||||
# the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block
|
# ----------------------------------------------------------
|
||||||
# Note that this will lead to some quantization in performance where time-taken jumps
|
# Create pointers for the first blocks of A and B.
|
||||||
# when you need to add a new block
|
# We will advance this pointer as we move in the K direction
|
||||||
n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
# and accumulate
|
||||||
n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||||
|
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
|
||||||
|
# see above `Pointer Arithmetics` section for details
|
||||||
|
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||||
|
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||||
|
|
||||||
# Map PIDs to the block they should compute. This is done in a grouped ordering
|
# -----------------------------------------------------------
|
||||||
# to promote L2 cache reuse.
|
# Iterate to compute a block of the C matrix
|
||||||
n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
|
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||||
group_id = pid // n_output_blocks_in_group
|
# of fp32 values for higher accuracy.
|
||||||
first_m_block_in_group = group_id * GROUP_SIZE_M
|
# `accumulator` will be converted back to fp16 after the loop
|
||||||
|
|
||||||
# If the number of blocks is not divisible by the group size, the last group is smaller
|
|
||||||
group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
|
|
||||||
|
|
||||||
# Within a group, we compute in col-major ordering, block_m and block_n are the
|
|
||||||
# output row and col that this program is computing in terms of blocks
|
|
||||||
block_m = first_m_block_in_group + (pid % group_size_m)
|
|
||||||
block_n = (pid % n_output_blocks_in_group) // group_size_m
|
|
||||||
|
|
||||||
# Convert from block indices back to element indices
|
|
||||||
m_start = block_m * BLOCK_SIZE_M
|
|
||||||
n_start = block_n * BLOCK_SIZE_N
|
|
||||||
|
|
||||||
# Expand out to all the offsets for each of the elements in this block.
|
|
||||||
m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
|
||||||
n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
|
||||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
|
||||||
|
|
||||||
# Get the pointers for the first block of each. We will advance this pointer
|
|
||||||
# as we move in the K direction and accumulate.
|
|
||||||
# a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
|
|
||||||
a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
|
|
||||||
# b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
|
|
||||||
b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
|
|
||||||
# We accumulate internally in fp32, but the output is written out in the dtype
|
|
||||||
# of the tensor when it is stored
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
for k in range(0, K, BLOCK_SIZE_K):
|
for k in range(0, K, BLOCK_SIZE_K):
|
||||||
# Note that for simplicity, we don't apply a mask here. This means that if K is
|
# Note that for simplicity, we don't apply a mask here.
|
||||||
# not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
|
# This means that if K is not a multiple of BLOCK_SIZE_K,
|
||||||
# accumulate it incorrectly.
|
# this will access out-of-bounds memory and produce an
|
||||||
|
# error or (worse!) incorrect results.
|
||||||
a = tl.load(a_ptrs)
|
a = tl.load(a_ptrs)
|
||||||
b = tl.load(b_ptrs)
|
b = tl.load(b_ptrs)
|
||||||
# We accumulate along the K dimension
|
# We accumulate along the K dimension
|
||||||
accumulator += tl.dot(a, b)
|
accumulator += tl.dot(a, b)
|
||||||
|
|
||||||
# Advance the ptrs to the next K block
|
# Advance the ptrs to the next K block
|
||||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
# triton can accept arbitrary activation function via metaparameters!
|
# you can fuse arbitrary activation functions here
|
||||||
if meta['ACTIVATION']:
|
# while the accumulator is still in FP32 !
|
||||||
|
if meta['ACTIVATION']:
|
||||||
accumulator = meta['ACTIVATION'](accumulator)
|
accumulator = meta['ACTIVATION'](accumulator)
|
||||||
|
c = accumulator.to(tl.float16)
|
||||||
|
|
||||||
m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
# -----------------------------------------------------------
|
||||||
n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
# Write back the block of the output matrix C
|
||||||
c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
mask = (m_offsets_c < M) & (n_offsets_c < N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
tl.store(c_ptrs, accumulator, mask=mask)
|
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
||||||
@@ -282,18 +281,11 @@ def matmul(a, b, activation=None):
|
|||||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||||
)
|
)
|
||||||
matmul_kernel[grid](
|
matmul_kernel[grid](
|
||||||
a,
|
a, b, c,
|
||||||
b,
|
M, N, K,
|
||||||
c,
|
a.stride(0), a.stride(1),
|
||||||
M,
|
b.stride(0), b.stride(1),
|
||||||
N,
|
c.stride(0), c.stride(1),
|
||||||
K,
|
|
||||||
a.stride(0),
|
|
||||||
a.stride(1),
|
|
||||||
b.stride(0),
|
|
||||||
b.stride(1),
|
|
||||||
c.stride(0),
|
|
||||||
c.stride(1),
|
|
||||||
ACTIVATION=activation,
|
ACTIVATION=activation,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 37 KiB After Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
Before Width: | Height: | Size: 55 KiB After Width: | Height: | Size: 54 KiB |
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 31 KiB |
@@ -234,10 +234,10 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
|||||||
0 4096.0 9.600000 9.600000
|
0 4096.0 9.600000 9.600000
|
||||||
1 8192.0 19.200000 19.200000
|
1 8192.0 19.200000 19.200000
|
||||||
2 16384.0 38.400001 38.400001
|
2 16384.0 38.400001 38.400001
|
||||||
3 32768.0 76.800002 76.800002
|
3 32768.0 63.999998 76.800002
|
||||||
4 65536.0 127.999995 127.999995
|
4 65536.0 127.999995 127.999995
|
||||||
5 131072.0 219.428568 219.428568
|
5 131072.0 219.428568 219.428568
|
||||||
6 262144.0 341.333321 384.000001
|
6 262144.0 384.000001 384.000001
|
||||||
7 524288.0 472.615390 472.615390
|
7 524288.0 472.615390 472.615390
|
||||||
8 1048576.0 614.400016 614.400016
|
8 1048576.0 614.400016 614.400016
|
||||||
9 2097152.0 722.823517 722.823517
|
9 2097152.0 722.823517 722.823517
|
||||||
@@ -254,7 +254,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
|||||||
|
|
||||||
.. rst-class:: sphx-glr-timing
|
.. rst-class:: sphx-glr-timing
|
||||||
|
|
||||||
**Total running time of the script:** ( 0 minutes 10.994 seconds)
|
**Total running time of the script:** ( 0 minutes 10.971 seconds)
|
||||||
|
|
||||||
|
|
||||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||||
|
@@ -306,10 +306,10 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
|||||||
3 640.0 682.666684 640.000002 160.000000
|
3 640.0 682.666684 640.000002 160.000000
|
||||||
4 768.0 702.171410 664.216187 163.839992
|
4 768.0 702.171410 664.216187 163.839992
|
||||||
.. ... ... ... ...
|
.. ... ... ... ...
|
||||||
93 12160.0 812.359066 406.179533 198.936606
|
93 12160.0 812.359066 405.755985 198.936606
|
||||||
94 12288.0 812.429770 416.101597 199.298541
|
94 12288.0 812.429770 415.222812 199.096718
|
||||||
95 12416.0 810.840807 412.149375 198.854847
|
95 12416.0 810.840807 411.296057 198.755369
|
||||||
96 12544.0 810.925276 412.971190 199.209928
|
96 12544.0 810.925276 412.971190 199.012395
|
||||||
97 12672.0 811.007961 412.097543 199.167004
|
97 12672.0 811.007961 412.097543 199.167004
|
||||||
|
|
||||||
[98 rows x 4 columns]
|
[98 rows x 4 columns]
|
||||||
@@ -328,7 +328,7 @@ In the above plot, we can see that:
|
|||||||
|
|
||||||
.. rst-class:: sphx-glr-timing
|
.. rst-class:: sphx-glr-timing
|
||||||
|
|
||||||
**Total running time of the script:** ( 1 minutes 12.617 seconds)
|
**Total running time of the script:** ( 1 minutes 12.739 seconds)
|
||||||
|
|
||||||
|
|
||||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||||
|
@@ -42,7 +42,7 @@ In this tutorial, you will learn how to implement efficient matrix multiplicatio
|
|||||||
yourself with Triton, in a way that is easy to customize and extend.
|
yourself with Triton, in a way that is easy to customize and extend.
|
||||||
|
|
||||||
Roughly speaking, the kernel that we will write will implement the following blocked
|
Roughly speaking, the kernel that we will write will implement the following blocked
|
||||||
algorithm to multiply a (MxK) by a (KxN) matrix:
|
algorithm to multiply a (M, K) by a (K, N) matrix:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@@ -57,9 +57,9 @@ algorithm to multiply a (MxK) by a (KxN) matrix:
|
|||||||
acc += dot(a, b)
|
acc += dot(a, b)
|
||||||
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
|
||||||
|
|
||||||
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
|
where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 44-129
|
.. GENERATED FROM PYTHON SOURCE LINES 44-137
|
||||||
|
|
||||||
Compute Kernel
|
Compute Kernel
|
||||||
----------------
|
----------------
|
||||||
@@ -73,35 +73,31 @@ Pointer Arithmetics
|
|||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
|
For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
|
||||||
y :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
|
y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
|
||||||
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
|
||||||
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
:code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
|
||||||
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
|
||||||
|
|
||||||
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
pid_m = triton.program_id(0)
|
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
pid_n = triton.program_id(1)
|
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
|
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||||
rk = triton.arange(0, BLOCK_SIZE_K)
|
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||||
// pointer for A operand
|
|
||||||
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
|
|
||||||
// pointer for B operand
|
|
||||||
pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
|
|
||||||
|
|
||||||
And then updated in the inner loop as follows:
|
And then updated in the inner loop as follows:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
pa += BLOCK_SIZE_K * stride_a_1;
|
pa += BLOCK_SIZE_K * stride_ak;
|
||||||
pb += BLOCK_SIZE_K * stride_b_0;
|
pb += BLOCK_SIZE_K * stride_bk;
|
||||||
|
|
||||||
|
|
||||||
L2 Cache Optimizations
|
L2 Cache Optimizations
|
||||||
@@ -129,13 +125,25 @@ switching to the next column:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
pid = triton.program_id(0);
|
# program ID
|
||||||
width = GROUP_M * grid_n;
|
pid = tl.program_id(axis=0)
|
||||||
group_id = pid // width;
|
# number of program ids along the M axis
|
||||||
# we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
|
# number of programs ids along the N axis
|
||||||
pid_m = group_id * GROUP_M + (pid % group_size);
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
pid_n = (pid % width) // (group_size);
|
# number of programs in group
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
# id of the group this program is in
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
# row-id of the first program in the group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
# if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
# *within groups*, programs are ordered in a column-major order
|
||||||
|
# row-id of the program in the *launch grid*
|
||||||
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
# col-id of the program in the *launch grid*
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
|
|
||||||
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
|
||||||
we can see that if we compute the output in row-major ordering, we need to load 90
|
we can see that if we compute the output in row-major ordering, we need to load 90
|
||||||
@@ -147,13 +155,13 @@ In practice, this can improve the performance of our matrix multiplication kerne
|
|||||||
more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).
|
||||||
|
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 131-134
|
.. GENERATED FROM PYTHON SOURCE LINES 139-142
|
||||||
|
|
||||||
Final Result
|
Final Result
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 134-263
|
.. GENERATED FROM PYTHON SOURCE LINES 142-262
|
||||||
|
|
||||||
.. code-block:: default
|
.. code-block:: default
|
||||||
|
|
||||||
@@ -190,26 +198,19 @@ Final Result
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def matmul_kernel(
|
def matmul_kernel(
|
||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
a_ptr,
|
a_ptr, b_ptr, c_ptr,
|
||||||
b_ptr,
|
|
||||||
c_ptr,
|
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
M,
|
M, N, K,
|
||||||
N,
|
|
||||||
K,
|
|
||||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||||
# by to get the element one row down (A has M rows)
|
# by to get the element one row down (A has M rows)
|
||||||
stride_am,
|
stride_am, stride_ak,
|
||||||
stride_ak,
|
stride_bk, stride_bn,
|
||||||
stride_bk,
|
stride_cm, stride_cn,
|
||||||
stride_bn,
|
# Meta-parameters
|
||||||
stride_cm,
|
|
||||||
stride_cn,
|
|
||||||
**meta,
|
**meta,
|
||||||
):
|
):
|
||||||
"""Kernel for computing the matmul AB = C
|
"""Kernel for computing the matmul C = A x B.
|
||||||
|
|
||||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||||
"""
|
"""
|
||||||
# extract meta-parameters
|
# extract meta-parameters
|
||||||
@@ -217,67 +218,65 @@ Final Result
|
|||||||
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
|
||||||
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
|
||||||
GROUP_SIZE_M = 8
|
GROUP_SIZE_M = 8
|
||||||
|
|
||||||
|
# -----------------------------------------------------------
|
||||||
|
# Map program ids `pid` to the block of C it should compute.
|
||||||
|
# This is done in a grouped ordering to promote L2 data reuse
|
||||||
|
# See above `L2 Cache Optimizations` section for details
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
|
|
||||||
# the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block
|
# ----------------------------------------------------------
|
||||||
# Note that this will lead to some quantization in performance where time-taken jumps
|
# Create pointers for the first blocks of A and B.
|
||||||
# when you need to add a new block
|
# We will advance this pointer as we move in the K direction
|
||||||
n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
|
# and accumulate
|
||||||
n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
|
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||||
|
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
|
||||||
|
# see above `Pointer Arithmetics` section for details
|
||||||
|
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
|
||||||
|
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
|
||||||
|
|
||||||
# Map PIDs to the block they should compute. This is done in a grouped ordering
|
# -----------------------------------------------------------
|
||||||
# to promote L2 cache reuse.
|
# Iterate to compute a block of the C matrix
|
||||||
n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
|
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||||
group_id = pid // n_output_blocks_in_group
|
# of fp32 values for higher accuracy.
|
||||||
first_m_block_in_group = group_id * GROUP_SIZE_M
|
# `accumulator` will be converted back to fp16 after the loop
|
||||||
|
|
||||||
# If the number of blocks is not divisible by the group size, the last group is smaller
|
|
||||||
group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
|
|
||||||
|
|
||||||
# Within a group, we compute in col-major ordering, block_m and block_n are the
|
|
||||||
# output row and col that this program is computing in terms of blocks
|
|
||||||
block_m = first_m_block_in_group + (pid % group_size_m)
|
|
||||||
block_n = (pid % n_output_blocks_in_group) // group_size_m
|
|
||||||
|
|
||||||
# Convert from block indices back to element indices
|
|
||||||
m_start = block_m * BLOCK_SIZE_M
|
|
||||||
n_start = block_n * BLOCK_SIZE_N
|
|
||||||
|
|
||||||
# Expand out to all the offsets for each of the elements in this block.
|
|
||||||
m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
|
||||||
n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
|
||||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
|
||||||
|
|
||||||
# Get the pointers for the first block of each. We will advance this pointer
|
|
||||||
# as we move in the K direction and accumulate.
|
|
||||||
# a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
|
|
||||||
a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
|
|
||||||
# b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
|
|
||||||
b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
|
|
||||||
# We accumulate internally in fp32, but the output is written out in the dtype
|
|
||||||
# of the tensor when it is stored
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
for k in range(0, K, BLOCK_SIZE_K):
|
for k in range(0, K, BLOCK_SIZE_K):
|
||||||
# Note that for simplicity, we don't apply a mask here. This means that if K is
|
# Note that for simplicity, we don't apply a mask here.
|
||||||
# not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
|
# This means that if K is not a multiple of BLOCK_SIZE_K,
|
||||||
# accumulate it incorrectly.
|
# this will access out-of-bounds memory and produce an
|
||||||
|
# error or (worse!) incorrect results.
|
||||||
a = tl.load(a_ptrs)
|
a = tl.load(a_ptrs)
|
||||||
b = tl.load(b_ptrs)
|
b = tl.load(b_ptrs)
|
||||||
# We accumulate along the K dimension
|
# We accumulate along the K dimension
|
||||||
accumulator += tl.dot(a, b)
|
accumulator += tl.dot(a, b)
|
||||||
|
|
||||||
# Advance the ptrs to the next K block
|
# Advance the ptrs to the next K block
|
||||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
# triton can accept arbitrary activation function via metaparameters!
|
# you can fuse arbitrary activation functions here
|
||||||
if meta['ACTIVATION']:
|
# while the accumulator is still in FP32 !
|
||||||
|
if meta['ACTIVATION']:
|
||||||
accumulator = meta['ACTIVATION'](accumulator)
|
accumulator = meta['ACTIVATION'](accumulator)
|
||||||
|
c = accumulator.to(tl.float16)
|
||||||
|
|
||||||
m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
|
# -----------------------------------------------------------
|
||||||
n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
|
# Write back the block of the output matrix C
|
||||||
c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
mask = (m_offsets_c < M) & (n_offsets_c < N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
tl.store(c_ptrs, accumulator, mask=mask)
|
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
||||||
@@ -293,12 +292,12 @@ Final Result
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 264-266
|
.. GENERATED FROM PYTHON SOURCE LINES 263-265
|
||||||
|
|
||||||
We can now create a convenience wrapper function that only takes two input tensors
|
We can now create a convenience wrapper function that only takes two input tensors
|
||||||
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 266-302
|
.. GENERATED FROM PYTHON SOURCE LINES 265-294
|
||||||
|
|
||||||
.. code-block:: default
|
.. code-block:: default
|
||||||
|
|
||||||
@@ -321,18 +320,11 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
|||||||
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
|
||||||
)
|
)
|
||||||
matmul_kernel[grid](
|
matmul_kernel[grid](
|
||||||
a,
|
a, b, c,
|
||||||
b,
|
M, N, K,
|
||||||
c,
|
a.stride(0), a.stride(1),
|
||||||
M,
|
b.stride(0), b.stride(1),
|
||||||
N,
|
c.stride(0), c.stride(1),
|
||||||
K,
|
|
||||||
a.stride(0),
|
|
||||||
a.stride(1),
|
|
||||||
b.stride(0),
|
|
||||||
b.stride(1),
|
|
||||||
c.stride(0),
|
|
||||||
c.stride(1),
|
|
||||||
ACTIVATION=activation,
|
ACTIVATION=activation,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
@@ -345,14 +337,14 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 303-307
|
.. GENERATED FROM PYTHON SOURCE LINES 295-299
|
||||||
|
|
||||||
Unit Test
|
Unit Test
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 307-320
|
.. GENERATED FROM PYTHON SOURCE LINES 299-312
|
||||||
|
|
||||||
.. code-block:: default
|
.. code-block:: default
|
||||||
|
|
||||||
@@ -400,7 +392,7 @@ We can test our custom matrix multiplication operation against a native torch im
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 321-327
|
.. GENERATED FROM PYTHON SOURCE LINES 313-319
|
||||||
|
|
||||||
Benchmark
|
Benchmark
|
||||||
--------------
|
--------------
|
||||||
@@ -409,7 +401,7 @@ Square Matrix Performance
|
|||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
|
||||||
|
|
||||||
.. GENERATED FROM PYTHON SOURCE LINES 327-368
|
.. GENERATED FROM PYTHON SOURCE LINES 319-360
|
||||||
|
|
||||||
.. code-block:: default
|
.. code-block:: default
|
||||||
|
|
||||||
@@ -471,37 +463,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
|||||||
matmul-performance:
|
matmul-performance:
|
||||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||||
0 128.0 0.455111 ... 0.512000 0.512000
|
0 128.0 0.455111 ... 0.512000 0.512000
|
||||||
1 256.0 2.730667 ... 2.978909 2.978909
|
1 256.0 2.978909 ... 2.978909 2.978909
|
||||||
2 384.0 7.372800 ... 8.507077 8.507077
|
2 384.0 7.372800 ... 8.507077 8.507077
|
||||||
3 512.0 14.563555 ... 15.420235 16.384000
|
3 512.0 14.563555 ... 16.384000 16.384000
|
||||||
4 640.0 22.260869 ... 24.380953 23.272727
|
4 640.0 22.260869 ... 24.380953 24.380953
|
||||||
5 768.0 32.768000 ... 34.028308 34.028308
|
5 768.0 32.768000 ... 34.028308 34.028308
|
||||||
6 896.0 39.025776 ... 40.140799 39.025776
|
6 896.0 39.025776 ... 40.140799 36.023796
|
||||||
7 1024.0 49.932191 ... 53.773130 52.428801
|
7 1024.0 49.932191 ... 52.428801 52.428801
|
||||||
8 1152.0 45.242181 ... 46.656000 46.656000
|
8 1152.0 44.566925 ... 46.656000 46.656000
|
||||||
9 1280.0 51.200001 ... 56.888887 56.888887
|
9 1280.0 51.200001 ... 56.888887 56.109587
|
||||||
10 1408.0 64.138541 ... 64.902096 64.902096
|
10 1408.0 64.138541 ... 64.902096 64.902096
|
||||||
11 1536.0 78.643199 ... 76.106321 75.296679
|
11 1536.0 78.643199 ... 76.106321 75.296679
|
||||||
12 1664.0 62.929456 ... 62.061463 62.061463
|
12 1664.0 63.372618 ... 62.492442 61.636381
|
||||||
13 1792.0 72.983276 ... 69.810085 69.379162
|
13 1792.0 72.983276 ... 69.810085 69.379162
|
||||||
14 1920.0 67.434145 ... 70.892307 70.530615
|
14 1920.0 67.434145 ... 70.892307 70.530615
|
||||||
15 2048.0 73.908442 ... 74.898285 74.565406
|
15 2048.0 73.908442 ... 75.234154 74.898285
|
||||||
16 2176.0 83.500614 ... 78.916269 79.855747
|
16 2176.0 81.472263 ... 80.817862 80.173899
|
||||||
17 2304.0 68.251065 ... 73.275679 72.828879
|
17 2304.0 68.446623 ... 73.501144 73.275679
|
||||||
18 2432.0 71.125224 ... 80.731218 80.731218
|
18 2432.0 71.305746 ... 81.197876 79.362895
|
||||||
19 2560.0 77.649287 ... 76.560748 76.382283
|
19 2560.0 77.649287 ... 77.649287 76.560748
|
||||||
20 2688.0 81.928846 ... 80.366642 82.823267
|
20 2688.0 82.642823 ... 80.708630 82.823267
|
||||||
21 2816.0 77.743683 ... 78.868366 78.301990
|
21 2816.0 79.587973 ... 79.733474 77.605356
|
||||||
22 2944.0 81.832567 ... 79.610276 78.605729
|
22 2944.0 81.967162 ... 78.112900 79.230573
|
||||||
23 3072.0 81.005868 ... 81.005868 82.420822
|
23 3072.0 81.707223 ... 84.135370 79.863336
|
||||||
24 3200.0 84.321474 ... 89.635851 85.106381
|
24 3200.0 84.099871 ... 87.074829 89.136491
|
||||||
25 3328.0 83.226931 ... 87.156532 86.113988
|
25 3328.0 83.905938 ... 84.003845 86.424125
|
||||||
26 3456.0 81.932484 ... 83.632331 85.313831
|
26 3456.0 81.518272 ... 85.494768 81.353753
|
||||||
27 3584.0 87.211821 ... 87.211821 91.563533
|
27 3584.0 86.540320 ... 94.448944 94.847460
|
||||||
28 3712.0 85.896254 ... 82.491612 84.874549
|
28 3712.0 83.947349 ... 88.955779 89.114488
|
||||||
29 3840.0 85.070769 ... 87.493673 87.701820
|
29 3840.0 84.809814 ... 88.191387 87.217666
|
||||||
30 3968.0 92.935215 ... 83.865247 83.578035
|
30 3968.0 93.148045 ... 83.179234 87.409694
|
||||||
31 4096.0 93.662059 ... 85.926841 84.840533
|
31 4096.0 93.531519 ... 89.777746 87.552332
|
||||||
|
|
||||||
[32 rows x 5 columns]
|
[32 rows x 5 columns]
|
||||||
|
|
||||||
@@ -511,7 +503,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
|||||||
|
|
||||||
.. rst-class:: sphx-glr-timing
|
.. rst-class:: sphx-glr-timing
|
||||||
|
|
||||||
**Total running time of the script:** ( 2 minutes 9.226 seconds)
|
**Total running time of the script:** ( 2 minutes 30.498 seconds)
|
||||||
|
|
||||||
|
|
||||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||||
|
@@ -5,12 +5,12 @@
|
|||||||
|
|
||||||
Computation times
|
Computation times
|
||||||
=================
|
=================
|
||||||
**03:32.837** total execution time for **getting-started_tutorials** files:
|
**03:54.208** total execution time for **getting-started_tutorials** files:
|
||||||
|
|
||||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:09.226 | 0.0 MB |
|
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 02:30.498 | 0.0 MB |
|
||||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.617 | 0.0 MB |
|
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 01:12.739 | 0.0 MB |
|
||||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.994 | 0.0 MB |
|
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 00:10.971 | 0.0 MB |
|
||||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||||
|
@@ -322,10 +322,10 @@ for different problem sizes.</p>
|
|||||||
0 4096.0 9.600000 9.600000
|
0 4096.0 9.600000 9.600000
|
||||||
1 8192.0 19.200000 19.200000
|
1 8192.0 19.200000 19.200000
|
||||||
2 16384.0 38.400001 38.400001
|
2 16384.0 38.400001 38.400001
|
||||||
3 32768.0 76.800002 76.800002
|
3 32768.0 63.999998 76.800002
|
||||||
4 65536.0 127.999995 127.999995
|
4 65536.0 127.999995 127.999995
|
||||||
5 131072.0 219.428568 219.428568
|
5 131072.0 219.428568 219.428568
|
||||||
6 262144.0 341.333321 384.000001
|
6 262144.0 384.000001 384.000001
|
||||||
7 524288.0 472.615390 472.615390
|
7 524288.0 472.615390 472.615390
|
||||||
8 1048576.0 614.400016 614.400016
|
8 1048576.0 614.400016 614.400016
|
||||||
9 2097152.0 722.823517 722.823517
|
9 2097152.0 722.823517 722.823517
|
||||||
@@ -337,7 +337,7 @@ for different problem sizes.</p>
|
|||||||
15 134217728.0 851.577704 850.656574
|
15 134217728.0 851.577704 850.656574
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.994 seconds)</p>
|
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 0 minutes 10.971 seconds)</p>
|
||||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-01-vector-add-py">
|
||||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||||
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
<p><a class="reference download internal" download="" href="../../_downloads/62d97d49a32414049819dd8bb8378080/01-vector-add.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">01-vector-add.py</span></code></a></p>
|
||||||
|
@@ -391,10 +391,10 @@ We will then compare its performance against (1) <code class="code docutils lite
|
|||||||
3 640.0 682.666684 640.000002 160.000000
|
3 640.0 682.666684 640.000002 160.000000
|
||||||
4 768.0 702.171410 664.216187 163.839992
|
4 768.0 702.171410 664.216187 163.839992
|
||||||
.. ... ... ... ...
|
.. ... ... ... ...
|
||||||
93 12160.0 812.359066 406.179533 198.936606
|
93 12160.0 812.359066 405.755985 198.936606
|
||||||
94 12288.0 812.429770 416.101597 199.298541
|
94 12288.0 812.429770 415.222812 199.096718
|
||||||
95 12416.0 810.840807 412.149375 198.854847
|
95 12416.0 810.840807 411.296057 198.755369
|
||||||
96 12544.0 810.925276 412.971190 199.209928
|
96 12544.0 810.925276 412.971190 199.012395
|
||||||
97 12672.0 811.007961 412.097543 199.167004
|
97 12672.0 811.007961 412.097543 199.167004
|
||||||
|
|
||||||
[98 rows x 4 columns]
|
[98 rows x 4 columns]
|
||||||
@@ -408,7 +408,7 @@ We will then compare its performance against (1) <code class="code docutils lite
|
|||||||
Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
|
Note however that the PyTorch <cite>softmax</cite> operation is more general and will works on tensors of any shape.</p></li>
|
||||||
</ul>
|
</ul>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.617 seconds)</p>
|
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 1 minutes 12.739 seconds)</p>
|
||||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-02-fused-softmax-py">
|
||||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||||
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
<p><a class="reference download internal" download="" href="../../_downloads/d91442ac2982c4e0cc3ab0f43534afbc/02-fused-softmax.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">02-fused-softmax.py</span></code></a></p>
|
||||||
|
@@ -221,7 +221,7 @@ to accomodate the needs of modern deep learning workloads (e.g., fused activatio
|
|||||||
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
In this tutorial, you will learn how to implement efficient matrix multiplications by
|
||||||
yourself with Triton, in a way that is easy to customize and extend.</p>
|
yourself with Triton, in a way that is easy to customize and extend.</p>
|
||||||
<p>Roughly speaking, the kernel that we will write will implement the following blocked
|
<p>Roughly speaking, the kernel that we will write will implement the following blocked
|
||||||
algorithm to multiply a (MxK) by a (KxN) matrix:</p>
|
algorithm to multiply a (M, K) by a (K, N) matrix:</p>
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># do in parallel</span>
|
||||||
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
|
<span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">):</span>
|
||||||
@@ -236,7 +236,7 @@ algorithm to multiply a (MxK) by a (KxN) matrix:</p>
|
|||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
<p>where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.</p>
|
<p>where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="section" id="compute-kernel">
|
<div class="section" id="compute-kernel">
|
||||||
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
<h2>Compute Kernel<a class="headerlink" href="#compute-kernel" title="Permalink to this headline">¶</a></h2>
|
||||||
@@ -247,33 +247,29 @@ multi-dimensional pointer arithmetics.</p>
|
|||||||
<div class="section" id="pointer-arithmetics">
|
<div class="section" id="pointer-arithmetics">
|
||||||
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
<h3>Pointer Arithmetics<a class="headerlink" href="#pointer-arithmetics" title="Permalink to this headline">¶</a></h3>
|
||||||
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given b
|
<p>For a row-major 2D tensor <code class="code docutils literal notranslate"><span class="pre">X</span></code>, the memory location of <code class="code docutils literal notranslate"><span class="pre">X[i,</span> <span class="pre">j]</span></code> is given b
|
||||||
y <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_x_0</span> <span class="pre">+</span> <span class="pre">j*stride_x_1</span></code>.
|
y <code class="code docutils literal notranslate"><span class="pre">&X[i,</span> <span class="pre">j]</span> <span class="pre">=</span> <span class="pre">X</span> <span class="pre">+</span> <span class="pre">i*stride_xi</span> <span class="pre">+</span> <span class="pre">j*stride_xj</span></code>.
|
||||||
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_SIZE_M,</span> <span class="pre">k:k+BLOCK_SIZE_K]</span></code> and
|
Therefore, blocks of pointers for <code class="code docutils literal notranslate"><span class="pre">A[m</span> <span class="pre">:</span> <span class="pre">m+BLOCK_SIZE_M,</span> <span class="pre">k:k+BLOCK_SIZE_K]</span></code> and
|
||||||
<code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_SIZE_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_SIZE_N]</span></code> can be defined in pseudo-code as:</p>
|
<code class="code docutils literal notranslate"><span class="pre">B[k</span> <span class="pre">:</span> <span class="pre">k+BLOCK_SIZE_K,</span> <span class="pre">n</span> <span class="pre">:</span> <span class="pre">n+BLOCK_SIZE_N]</span></code> can be defined in pseudo-code as:</p>
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">&</span><span class="n">A</span><span class="p">[</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span><span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">]</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">m</span> <span class="p">:</span> <span class="n">m</span><span class="o">+</span><span class="n">BLOCK_SIZE_M</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">A</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||||||
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
<span class="o">&</span><span class="n">B</span><span class="p">[</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">,</span> <span class="n">n</span><span class="p">:</span><span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">]</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">k</span> <span class="p">:</span> <span class="n">k</span><span class="o">+</span><span class="n">BLOCK_SIZE_K</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="p">:</span> <span class="n">n</span><span class="o">+</span><span class="n">BLOCK_SIZE_N</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">B</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
<p>Which means that pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
|
<p>Which means that pointers for blocks of A and B can be initialized (i.e., <code class="code docutils literal notranslate"><span class="pre">k=0</span></code>) in Triton as:</p>
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pid_m</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||||
<span class="n">pid_n</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||||
<span class="n">rm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
||||||
<span class="n">rn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_ak</span><span class="p">)</span>
|
||||||
<span class="n">rk</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span> <span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_bn</span><span class="p">)</span>
|
||||||
<span class="o">//</span> <span class="n">pointer</span> <span class="k">for</span> <span class="n">A</span> <span class="n">operand</span>
|
|
||||||
<span class="n">pa</span> <span class="o">=</span> <span class="n">A</span> <span class="o">+</span> <span class="p">(</span><span class="n">rm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_a_0</span> <span class="o">+</span> <span class="n">rk</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_a_1</span><span class="p">);</span>
|
|
||||||
<span class="o">//</span> <span class="n">pointer</span> <span class="k">for</span> <span class="n">B</span> <span class="n">operand</span>
|
|
||||||
<span class="n">pb</span> <span class="o">=</span> <span class="n">B</span> <span class="o">+</span> <span class="p">(</span><span class="n">rk</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">stride_b_0</span> <span class="o">+</span> <span class="n">rn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">stride_b_1</span><span class="p">);</span>
|
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
<p>And then updated in the inner loop as follows:</p>
|
<p>And then updated in the inner loop as follows:</p>
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_a_1</span><span class="p">;</span>
|
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pa</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span><span class="p">;</span>
|
||||||
<span class="n">pb</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_b_0</span><span class="p">;</span>
|
<span class="n">pb</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span><span class="p">;</span>
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
@@ -299,13 +295,25 @@ a simple row-major ordering</p>
|
|||||||
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_M</span></code> rows before
|
This can be done by ‘super-grouping’ blocks in groups of <code class="code docutils literal notranslate"><span class="pre">GROUP_M</span></code> rows before
|
||||||
switching to the next column:</p>
|
switching to the next column:</p>
|
||||||
<blockquote>
|
<blockquote>
|
||||||
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">pid</span> <span class="o">=</span> <span class="n">triton</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="mi">0</span><span class="p">);</span>
|
<div><div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># program ID</span>
|
||||||
<span class="n">width</span> <span class="o">=</span> <span class="n">GROUP_M</span> <span class="o">*</span> <span class="n">grid_n</span><span class="p">;</span>
|
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">width</span><span class="p">;</span>
|
<span class="c1"># number of program ids along the M axis</span>
|
||||||
<span class="c1"># we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0</span>
|
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">grid_m</span> <span class="o">-</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span><span class="p">,</span> <span class="n">GROUP_M</span><span class="p">);</span>
|
<span class="c1"># number of programs ids along the N axis</span>
|
||||||
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_M</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size</span><span class="p">);</span>
|
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||||
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">width</span><span class="p">)</span> <span class="o">//</span> <span class="p">(</span><span class="n">group_size</span><span class="p">);</span>
|
<span class="c1"># number of programs in group</span>
|
||||||
|
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
|
||||||
|
<span class="c1"># id of the group this program is in</span>
|
||||||
|
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
|
||||||
|
<span class="c1"># row-id of the first program in the group</span>
|
||||||
|
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
||||||
|
<span class="c1"># if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller</span>
|
||||||
|
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
||||||
|
<span class="c1"># *within groups*, programs are ordered in a column-major order</span>
|
||||||
|
<span class="c1"># row-id of the program in the *launch grid*</span>
|
||||||
|
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
||||||
|
<span class="c1"># col-id of the program in the *launch grid*</span>
|
||||||
|
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
</div></blockquote>
|
</div></blockquote>
|
||||||
@@ -354,26 +362,19 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
|
|||||||
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
<span class="nd">@triton</span><span class="o">.</span><span class="n">jit</span>
|
||||||
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
|
<span class="k">def</span> <span class="nf">matmul_kernel</span><span class="p">(</span>
|
||||||
<span class="c1"># Pointers to matrices</span>
|
<span class="c1"># Pointers to matrices</span>
|
||||||
<span class="n">a_ptr</span><span class="p">,</span>
|
<span class="n">a_ptr</span><span class="p">,</span> <span class="n">b_ptr</span><span class="p">,</span> <span class="n">c_ptr</span><span class="p">,</span>
|
||||||
<span class="n">b_ptr</span><span class="p">,</span>
|
|
||||||
<span class="n">c_ptr</span><span class="p">,</span>
|
|
||||||
<span class="c1"># Matrix dimensions</span>
|
<span class="c1"># Matrix dimensions</span>
|
||||||
<span class="n">M</span><span class="p">,</span>
|
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
|
||||||
<span class="n">N</span><span class="p">,</span>
|
|
||||||
<span class="n">K</span><span class="p">,</span>
|
|
||||||
<span class="c1"># The stride variables represent how much to increase the ptr by when moving by 1</span>
|
<span class="c1"># The stride variables represent how much to increase the ptr by when moving by 1</span>
|
||||||
<span class="c1"># element in a particular dimension. E.g. stride_am is how much to increase a_ptr</span>
|
<span class="c1"># element in a particular dimension. E.g. stride_am is how much to increase a_ptr</span>
|
||||||
<span class="c1"># by to get the element one row down (A has M rows)</span>
|
<span class="c1"># by to get the element one row down (A has M rows)</span>
|
||||||
<span class="n">stride_am</span><span class="p">,</span>
|
<span class="n">stride_am</span><span class="p">,</span> <span class="n">stride_ak</span><span class="p">,</span>
|
||||||
<span class="n">stride_ak</span><span class="p">,</span>
|
<span class="n">stride_bk</span><span class="p">,</span> <span class="n">stride_bn</span><span class="p">,</span>
|
||||||
<span class="n">stride_bk</span><span class="p">,</span>
|
<span class="n">stride_cm</span><span class="p">,</span> <span class="n">stride_cn</span><span class="p">,</span>
|
||||||
<span class="n">stride_bn</span><span class="p">,</span>
|
<span class="c1"># Meta-parameters</span>
|
||||||
<span class="n">stride_cm</span><span class="p">,</span>
|
|
||||||
<span class="n">stride_cn</span><span class="p">,</span>
|
|
||||||
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
|
<span class="o">**</span><span class="n">meta</span><span class="p">,</span>
|
||||||
<span class="p">):</span>
|
<span class="p">):</span>
|
||||||
<span class="sd">"""Kernel for computing the matmul AB = C</span>
|
<span class="sd">"""Kernel for computing the matmul C = A x B.</span>
|
||||||
|
|
||||||
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
|
<span class="sd"> A has shape (M, K), B has shape (K, N) and C has shape (M, N)</span>
|
||||||
<span class="sd"> """</span>
|
<span class="sd"> """</span>
|
||||||
<span class="c1"># extract meta-parameters</span>
|
<span class="c1"># extract meta-parameters</span>
|
||||||
@@ -381,67 +382,65 @@ more than 10% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100).</
|
|||||||
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]</span>
|
<span class="n">BLOCK_SIZE_N</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]</span>
|
||||||
<span class="n">BLOCK_SIZE_K</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_K'</span><span class="p">]</span>
|
<span class="n">BLOCK_SIZE_K</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_K'</span><span class="p">]</span>
|
||||||
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">8</span>
|
<span class="n">GROUP_SIZE_M</span> <span class="o">=</span> <span class="mi">8</span>
|
||||||
|
|
||||||
|
<span class="c1"># -----------------------------------------------------------</span>
|
||||||
|
<span class="c1"># Map program ids `pid` to the block of C it should compute.</span>
|
||||||
|
<span class="c1"># This is done in a grouped ordering to promote L2 data reuse</span>
|
||||||
|
<span class="c1"># See above `L2 Cache Optimizations` section for details</span>
|
||||||
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
<span class="n">pid</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">program_id</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||||||
|
<span class="n">num_pid_m</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||||
|
<span class="n">num_pid_n</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||||
|
<span class="n">num_pid_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">num_pid_n</span>
|
||||||
|
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">num_pid_in_group</span>
|
||||||
|
<span class="n">first_pid_m</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
||||||
|
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_pid_m</span> <span class="o">-</span> <span class="n">first_pid_m</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
||||||
|
<span class="n">pid_m</span> <span class="o">=</span> <span class="n">first_pid_m</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
||||||
|
<span class="n">pid_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">num_pid_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
||||||
|
|
||||||
<span class="c1"># the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block</span>
|
<span class="c1"># ----------------------------------------------------------</span>
|
||||||
<span class="c1"># Note that this will lead to some quantization in performance where time-taken jumps</span>
|
<span class="c1"># Create pointers for the first blocks of A and B.</span>
|
||||||
<span class="c1"># when you need to add a new block</span>
|
<span class="c1"># We will advance this pointer as we move in the K direction</span>
|
||||||
<span class="n">n_blocks_m</span> <span class="o">=</span> <span class="p">(</span><span class="n">M</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_M</span>
|
<span class="c1"># and accumulate</span>
|
||||||
<span class="n">n_blocks_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="n">BLOCK_SIZE_N</span>
|
<span class="c1"># a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers</span>
|
||||||
|
<span class="c1"># b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers</span>
|
||||||
|
<span class="c1"># see above `Pointer Arithmetics` section for details</span>
|
||||||
|
<span class="n">offs_am</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||||
|
<span class="n">offs_bn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||||
|
<span class="n">offs_k</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
||||||
|
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_am</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_am</span> <span class="o">+</span> <span class="n">offs_k</span> <span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_ak</span><span class="p">)</span>
|
||||||
|
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">offs_k</span> <span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">*</span><span class="n">stride_bk</span> <span class="o">+</span> <span class="n">offs_bn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span><span class="o">*</span><span class="n">stride_bn</span><span class="p">)</span>
|
||||||
|
|
||||||
<span class="c1"># Map PIDs to the block they should compute. This is done in a grouped ordering</span>
|
<span class="c1"># -----------------------------------------------------------</span>
|
||||||
<span class="c1"># to promote L2 cache reuse.</span>
|
<span class="c1"># Iterate to compute a block of the C matrix</span>
|
||||||
<span class="n">n_output_blocks_in_group</span> <span class="o">=</span> <span class="n">GROUP_SIZE_M</span> <span class="o">*</span> <span class="n">n_blocks_n</span>
|
<span class="c1"># We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block</span>
|
||||||
<span class="n">group_id</span> <span class="o">=</span> <span class="n">pid</span> <span class="o">//</span> <span class="n">n_output_blocks_in_group</span>
|
<span class="c1"># of fp32 values for higher accuracy.</span>
|
||||||
<span class="n">first_m_block_in_group</span> <span class="o">=</span> <span class="n">group_id</span> <span class="o">*</span> <span class="n">GROUP_SIZE_M</span>
|
<span class="c1"># `accumulator` will be converted back to fp16 after the loop</span>
|
||||||
|
|
||||||
<span class="c1"># If the number of blocks is not divisible by the group size, the last group is smaller</span>
|
|
||||||
<span class="n">group_size_m</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">n_blocks_m</span> <span class="o">-</span> <span class="n">first_m_block_in_group</span><span class="p">,</span> <span class="n">GROUP_SIZE_M</span><span class="p">)</span>
|
|
||||||
|
|
||||||
<span class="c1"># Within a group, we compute in col-major ordering, block_m and block_n are the</span>
|
|
||||||
<span class="c1"># output row and col that this program is computing in terms of blocks</span>
|
|
||||||
<span class="n">block_m</span> <span class="o">=</span> <span class="n">first_m_block_in_group</span> <span class="o">+</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">group_size_m</span><span class="p">)</span>
|
|
||||||
<span class="n">block_n</span> <span class="o">=</span> <span class="p">(</span><span class="n">pid</span> <span class="o">%</span> <span class="n">n_output_blocks_in_group</span><span class="p">)</span> <span class="o">//</span> <span class="n">group_size_m</span>
|
|
||||||
|
|
||||||
<span class="c1"># Convert from block indices back to element indices</span>
|
|
||||||
<span class="n">m_start</span> <span class="o">=</span> <span class="n">block_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span>
|
|
||||||
<span class="n">n_start</span> <span class="o">=</span> <span class="n">block_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span>
|
|
||||||
|
|
||||||
<span class="c1"># Expand out to all the offsets for each of the elements in this block.</span>
|
|
||||||
<span class="n">m_offsets_a</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">))[:,</span> <span class="kc">None</span><span class="p">]</span>
|
|
||||||
<span class="n">n_offsets_b</span> <span class="o">=</span> <span class="p">(</span><span class="n">n_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">))[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
|
||||||
<span class="n">k_offsets</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">)</span>
|
|
||||||
|
|
||||||
<span class="c1"># Get the pointers for the first block of each. We will advance this pointer</span>
|
|
||||||
<span class="c1"># as we move in the K direction and accumulate.</span>
|
|
||||||
<span class="c1"># a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers</span>
|
|
||||||
<span class="n">a_ptrs</span> <span class="o">=</span> <span class="n">a_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">stride_am</span> <span class="o">*</span> <span class="n">m_offsets_a</span> <span class="o">+</span> <span class="n">stride_ak</span> <span class="o">*</span> <span class="n">k_offsets</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:])</span>
|
|
||||||
<span class="c1"># b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers</span>
|
|
||||||
<span class="n">b_ptrs</span> <span class="o">=</span> <span class="n">b_ptr</span> <span class="o">+</span> <span class="p">(</span><span class="n">stride_bk</span> <span class="o">*</span> <span class="n">k_offsets</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">stride_bn</span> <span class="o">*</span> <span class="n">n_offsets_b</span><span class="p">)</span>
|
|
||||||
<span class="c1"># We accumulate internally in fp32, but the output is written out in the dtype</span>
|
|
||||||
<span class="c1"># of the tensor when it is stored</span>
|
|
||||||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">BLOCK_SIZE_M</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tl</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||||||
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
|
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">BLOCK_SIZE_K</span><span class="p">):</span>
|
||||||
<span class="c1"># Note that for simplicity, we don't apply a mask here. This means that if K is</span>
|
<span class="c1"># Note that for simplicity, we don't apply a mask here.</span>
|
||||||
<span class="c1"># not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and</span>
|
<span class="c1"># This means that if K is not a multiple of BLOCK_SIZE_K,</span>
|
||||||
<span class="c1"># accumulate it incorrectly.</span>
|
<span class="c1"># this will access out-of-bounds memory and produce an</span>
|
||||||
|
<span class="c1"># error or (worse!) incorrect results.</span>
|
||||||
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">a_ptrs</span><span class="p">)</span>
|
<span class="n">a</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">a_ptrs</span><span class="p">)</span>
|
||||||
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">b_ptrs</span><span class="p">)</span>
|
<span class="n">b</span> <span class="o">=</span> <span class="n">tl</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">b_ptrs</span><span class="p">)</span>
|
||||||
<span class="c1"># We accumulate along the K dimension</span>
|
<span class="c1"># We accumulate along the K dimension</span>
|
||||||
<span class="n">accumulator</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
<span class="n">accumulator</span> <span class="o">+=</span> <span class="n">tl</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||||||
|
|
||||||
<span class="c1"># Advance the ptrs to the next K block</span>
|
<span class="c1"># Advance the ptrs to the next K block</span>
|
||||||
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
<span class="n">a_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_ak</span>
|
||||||
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
<span class="n">b_ptrs</span> <span class="o">+=</span> <span class="n">BLOCK_SIZE_K</span> <span class="o">*</span> <span class="n">stride_bk</span>
|
||||||
<span class="c1"># triton can accept arbitrary activation function via metaparameters!</span>
|
<span class="c1"># you can fuse arbitrary activation functions here</span>
|
||||||
|
<span class="c1"># while the accumulator is still in FP32 !</span>
|
||||||
<span class="k">if</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
<span class="k">if</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">]:</span>
|
||||||
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">accumulator</span><span class="p">)</span>
|
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">'ACTIVATION'</span><span class="p">](</span><span class="n">accumulator</span><span class="p">)</span>
|
||||||
|
<span class="n">c</span> <span class="o">=</span> <span class="n">accumulator</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">tl</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
|
||||||
|
|
||||||
<span class="n">m_offsets_c</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">))[:,</span> <span class="kc">None</span><span class="p">]</span>
|
<span class="c1"># -----------------------------------------------------------</span>
|
||||||
<span class="n">n_offsets_c</span> <span class="o">=</span> <span class="p">(</span><span class="n">n_start</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">))[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
<span class="c1"># Write back the block of the output matrix C</span>
|
||||||
<span class="n">c_ptrs</span> <span class="o">=</span> <span class="n">c_ptr</span> <span class="o">+</span> <span class="n">stride_cm</span> <span class="o">*</span> <span class="n">m_offsets_c</span> <span class="o">+</span> <span class="n">stride_cn</span> <span class="o">*</span> <span class="n">n_offsets_c</span>
|
<span class="n">offs_cm</span> <span class="o">=</span> <span class="n">pid_m</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_M</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_M</span><span class="p">)</span>
|
||||||
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">m_offsets_c</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">n_offsets_c</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
<span class="n">offs_cn</span> <span class="o">=</span> <span class="n">pid_n</span> <span class="o">*</span> <span class="n">BLOCK_SIZE_N</span> <span class="o">+</span> <span class="n">tl</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">BLOCK_SIZE_N</span><span class="p">)</span>
|
||||||
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">c_ptrs</span><span class="p">,</span> <span class="n">accumulator</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span>
|
<span class="n">c_ptrs</span> <span class="o">=</span> <span class="n">c_ptr</span> <span class="o">+</span> <span class="n">stride_cm</span> <span class="o">*</span> <span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">stride_cn</span> <span class="o">*</span> <span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
|
||||||
|
<span class="n">c_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">offs_cm</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o"><</span> <span class="n">M</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">offs_cn</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o"><</span> <span class="n">N</span><span class="p">)</span>
|
||||||
|
<span class="n">tl</span><span class="o">.</span><span class="n">store</span><span class="p">(</span><span class="n">c_ptrs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">c_mask</span><span class="p">)</span>
|
||||||
|
|
||||||
|
|
||||||
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
|
<span class="c1"># we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`</span>
|
||||||
@@ -469,18 +468,11 @@ and (1) checks any shape constraint; (2) allocates the output; (3) launches the
|
|||||||
<span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]),</span>
|
<span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_M'</span><span class="p">])</span> <span class="o">*</span> <span class="n">triton</span><span class="o">.</span><span class="n">cdiv</span><span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">META</span><span class="p">[</span><span class="s1">'BLOCK_SIZE_N'</span><span class="p">]),</span>
|
||||||
<span class="p">)</span>
|
<span class="p">)</span>
|
||||||
<span class="n">matmul_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
<span class="n">matmul_kernel</span><span class="p">[</span><span class="n">grid</span><span class="p">](</span>
|
||||||
<span class="n">a</span><span class="p">,</span>
|
<span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span>
|
||||||
<span class="n">b</span><span class="p">,</span>
|
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span>
|
||||||
<span class="n">c</span><span class="p">,</span>
|
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||||
<span class="n">M</span><span class="p">,</span>
|
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||||
<span class="n">N</span><span class="p">,</span>
|
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||||||
<span class="n">K</span><span class="p">,</span>
|
|
||||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
||||||
<span class="n">a</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
|
||||||
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
||||||
<span class="n">b</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
|
||||||
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
||||||
<span class="n">c</span><span class="o">.</span><span class="n">stride</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
|
||||||
<span class="n">ACTIVATION</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span>
|
<span class="n">ACTIVATION</span><span class="o">=</span><span class="n">activation</span><span class="p">,</span>
|
||||||
<span class="p">)</span>
|
<span class="p">)</span>
|
||||||
<span class="k">return</span> <span class="n">c</span>
|
<span class="k">return</span> <span class="n">c</span>
|
||||||
@@ -575,42 +567,42 @@ torch_output=tensor([[ 1.1045, -36.9688, 31.4688, ..., -11.3906, 24.4531, -3
|
|||||||
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
|
<div class="sphx-glr-script-out highlight-none notranslate"><div class="highlight"><pre><span></span>matmul-performance:
|
||||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||||
0 128.0 0.455111 ... 0.512000 0.512000
|
0 128.0 0.455111 ... 0.512000 0.512000
|
||||||
1 256.0 2.730667 ... 2.978909 2.978909
|
1 256.0 2.978909 ... 2.978909 2.978909
|
||||||
2 384.0 7.372800 ... 8.507077 8.507077
|
2 384.0 7.372800 ... 8.507077 8.507077
|
||||||
3 512.0 14.563555 ... 15.420235 16.384000
|
3 512.0 14.563555 ... 16.384000 16.384000
|
||||||
4 640.0 22.260869 ... 24.380953 23.272727
|
4 640.0 22.260869 ... 24.380953 24.380953
|
||||||
5 768.0 32.768000 ... 34.028308 34.028308
|
5 768.0 32.768000 ... 34.028308 34.028308
|
||||||
6 896.0 39.025776 ... 40.140799 39.025776
|
6 896.0 39.025776 ... 40.140799 36.023796
|
||||||
7 1024.0 49.932191 ... 53.773130 52.428801
|
7 1024.0 49.932191 ... 52.428801 52.428801
|
||||||
8 1152.0 45.242181 ... 46.656000 46.656000
|
8 1152.0 44.566925 ... 46.656000 46.656000
|
||||||
9 1280.0 51.200001 ... 56.888887 56.888887
|
9 1280.0 51.200001 ... 56.888887 56.109587
|
||||||
10 1408.0 64.138541 ... 64.902096 64.902096
|
10 1408.0 64.138541 ... 64.902096 64.902096
|
||||||
11 1536.0 78.643199 ... 76.106321 75.296679
|
11 1536.0 78.643199 ... 76.106321 75.296679
|
||||||
12 1664.0 62.929456 ... 62.061463 62.061463
|
12 1664.0 63.372618 ... 62.492442 61.636381
|
||||||
13 1792.0 72.983276 ... 69.810085 69.379162
|
13 1792.0 72.983276 ... 69.810085 69.379162
|
||||||
14 1920.0 67.434145 ... 70.892307 70.530615
|
14 1920.0 67.434145 ... 70.892307 70.530615
|
||||||
15 2048.0 73.908442 ... 74.898285 74.565406
|
15 2048.0 73.908442 ... 75.234154 74.898285
|
||||||
16 2176.0 83.500614 ... 78.916269 79.855747
|
16 2176.0 81.472263 ... 80.817862 80.173899
|
||||||
17 2304.0 68.251065 ... 73.275679 72.828879
|
17 2304.0 68.446623 ... 73.501144 73.275679
|
||||||
18 2432.0 71.125224 ... 80.731218 80.731218
|
18 2432.0 71.305746 ... 81.197876 79.362895
|
||||||
19 2560.0 77.649287 ... 76.560748 76.382283
|
19 2560.0 77.649287 ... 77.649287 76.560748
|
||||||
20 2688.0 81.928846 ... 80.366642 82.823267
|
20 2688.0 82.642823 ... 80.708630 82.823267
|
||||||
21 2816.0 77.743683 ... 78.868366 78.301990
|
21 2816.0 79.587973 ... 79.733474 77.605356
|
||||||
22 2944.0 81.832567 ... 79.610276 78.605729
|
22 2944.0 81.967162 ... 78.112900 79.230573
|
||||||
23 3072.0 81.005868 ... 81.005868 82.420822
|
23 3072.0 81.707223 ... 84.135370 79.863336
|
||||||
24 3200.0 84.321474 ... 89.635851 85.106381
|
24 3200.0 84.099871 ... 87.074829 89.136491
|
||||||
25 3328.0 83.226931 ... 87.156532 86.113988
|
25 3328.0 83.905938 ... 84.003845 86.424125
|
||||||
26 3456.0 81.932484 ... 83.632331 85.313831
|
26 3456.0 81.518272 ... 85.494768 81.353753
|
||||||
27 3584.0 87.211821 ... 87.211821 91.563533
|
27 3584.0 86.540320 ... 94.448944 94.847460
|
||||||
28 3712.0 85.896254 ... 82.491612 84.874549
|
28 3712.0 83.947349 ... 88.955779 89.114488
|
||||||
29 3840.0 85.070769 ... 87.493673 87.701820
|
29 3840.0 84.809814 ... 88.191387 87.217666
|
||||||
30 3968.0 92.935215 ... 83.865247 83.578035
|
30 3968.0 93.148045 ... 83.179234 87.409694
|
||||||
31 4096.0 93.662059 ... 85.926841 84.840533
|
31 4096.0 93.531519 ... 89.777746 87.552332
|
||||||
|
|
||||||
[32 rows x 5 columns]
|
[32 rows x 5 columns]
|
||||||
</pre></div>
|
</pre></div>
|
||||||
</div>
|
</div>
|
||||||
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 9.226 seconds)</p>
|
<p class="sphx-glr-timing"><strong>Total running time of the script:</strong> ( 2 minutes 30.498 seconds)</p>
|
||||||
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
<div class="sphx-glr-footer class sphx-glr-footer-example docutils container" id="sphx-glr-download-getting-started-tutorials-03-matrix-multiplication-py">
|
||||||
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
<div class="sphx-glr-download sphx-glr-download-python docutils container">
|
||||||
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
<p><a class="reference download internal" download="" href="../../_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py"><code class="xref download docutils literal notranslate"><span class="pre">Download</span> <span class="pre">Python</span> <span class="pre">source</span> <span class="pre">code:</span> <span class="pre">03-matrix-multiplication.py</span></code></a></p>
|
||||||
|
@@ -174,7 +174,7 @@
|
|||||||
|
|
||||||
<div class="section" id="computation-times">
|
<div class="section" id="computation-times">
|
||||||
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
<span id="sphx-glr-getting-started-tutorials-sg-execution-times"></span><h1>Computation times<a class="headerlink" href="#computation-times" title="Permalink to this headline">¶</a></h1>
|
||||||
<p><strong>03:32.837</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
<p><strong>03:54.208</strong> total execution time for <strong>getting-started_tutorials</strong> files:</p>
|
||||||
<table class="docutils align-default">
|
<table class="docutils align-default">
|
||||||
<colgroup>
|
<colgroup>
|
||||||
<col style="width: 85%" />
|
<col style="width: 85%" />
|
||||||
@@ -183,15 +183,15 @@
|
|||||||
</colgroup>
|
</colgroup>
|
||||||
<tbody>
|
<tbody>
|
||||||
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
<tr class="row-odd"><td><p><a class="reference internal" href="03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py"><span class="std std-ref">Matrix Multiplication</span></a> (<code class="docutils literal notranslate"><span class="pre">03-matrix-multiplication.py</span></code>)</p></td>
|
||||||
<td><p>02:09.226</p></td>
|
<td><p>02:30.498</p></td>
|
||||||
<td><p>0.0 MB</p></td>
|
<td><p>0.0 MB</p></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
<tr class="row-even"><td><p><a class="reference internal" href="02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py"><span class="std std-ref">Fused Softmax</span></a> (<code class="docutils literal notranslate"><span class="pre">02-fused-softmax.py</span></code>)</p></td>
|
||||||
<td><p>01:12.617</p></td>
|
<td><p>01:12.739</p></td>
|
||||||
<td><p>0.0 MB</p></td>
|
<td><p>0.0 MB</p></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
<tr class="row-odd"><td><p><a class="reference internal" href="01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py"><span class="std std-ref">Vector Addition</span></a> (<code class="docutils literal notranslate"><span class="pre">01-vector-add.py</span></code>)</p></td>
|
||||||
<td><p>00:10.994</p></td>
|
<td><p>00:10.971</p></td>
|
||||||
<td><p>0.0 MB</p></td>
|
<td><p>0.0 MB</p></td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
|