[GH-PAGES] Updated website

This commit is contained in:
Philippe Tillet
2022-07-14 07:22:19 +00:00
parent 3e815114fd
commit d1c6625bfd
179 changed files with 2617 additions and 369 deletions

View File

@@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
3 32768.0 76.800002 76.800002
4 65536.0 127.999995 127.999995
5 131072.0 219.428568 219.428568
6 262144.0 341.333321 384.000001
6 262144.0 341.333321 341.333321
7 524288.0 472.615390 472.615390
8 1048576.0 614.400016 614.400016
9 2097152.0 722.823517 722.823517
@@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 1 minutes 34.829 seconds)
**Total running time of the script:** ( 1 minutes 50.020 seconds)
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:

View File

@@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
softmax-performance:
N Triton Torch (native) Torch (jit)
0 256.0 546.133347 512.000001 188.321838
1 384.0 614.400016 585.142862 153.600004
2 512.0 655.360017 585.142849 154.566038
0 256.0 546.133347 546.133347 186.181817
1 384.0 614.400016 585.142862 151.703707
2 512.0 655.360017 606.814814 154.566038
3 640.0 706.206879 640.000002 160.000000
4 768.0 722.823517 664.216187 162.754967
4 768.0 722.823517 664.216187 163.839992
.. ... ... ... ...
93 12160.0 812.359066 406.179533 198.936606
94 12288.0 812.429770 415.222812 199.298541
95 12416.0 812.498981 412.149375 198.954424
96 12544.0 812.566838 412.758863 199.209928
97 12672.0 811.007961 412.097543 199.264875
94 12288.0 812.429770 415.222812 199.096718
95 12416.0 812.498981 412.149375 198.854847
96 12544.0 810.925276 412.971190 199.012395
97 12672.0 811.007961 412.097543 199.167004
[98 rows x 4 columns]
@@ -306,7 +306,7 @@ In the above plot, we can see that:
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 3 minutes 18.076 seconds)
**Total running time of the script:** ( 3 minutes 32.089 seconds)
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

View File

@@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
matmul-performance:
M cuBLAS ... Triton Triton (+ LeakyReLU)
0 256.0 2.730667 ... 2.978909 2.978909
0 256.0 2.730667 ... 3.276800 2.978909
1 384.0 7.372800 ... 7.899428 7.899428
2 512.0 14.563555 ... 15.420235 15.420235
2 512.0 14.563555 ... 16.384000 15.420235
3 640.0 22.260869 ... 24.380953 24.380953
4 768.0 32.768000 ... 35.389441 34.028308
5 896.0 37.971025 ... 40.140799 39.025776
5 896.0 39.025776 ... 40.140799 39.025776
6 1024.0 49.932191 ... 53.773130 52.428801
7 1152.0 45.242181 ... 48.161033 47.396572
7 1152.0 45.242181 ... 47.396572 47.396572
8 1280.0 51.200001 ... 57.690139 57.690139
9 1408.0 64.138541 ... 68.147202 65.684049
10 1536.0 79.526831 ... 81.355034 78.643199
11 1664.0 63.372618 ... 63.372618 62.492442
9 1408.0 64.138541 ... 68.147202 66.485074
10 1536.0 80.430545 ... 80.430545 78.643199
11 1664.0 62.929456 ... 63.372618 62.492442
12 1792.0 72.983276 ... 72.983276 59.154861
13 1920.0 68.776119 ... 71.626943 70.892307
14 2048.0 73.584279 ... 78.033565 76.959706
15 2176.0 83.155572 ... 87.494120 86.367588
16 2304.0 68.446623 ... 78.064941 77.057651
17 2432.0 71.305746 ... 86.179335 85.393507
18 2560.0 77.833728 ... 82.956960 81.715711
19 2688.0 83.737433 ... 91.185232 89.464755
20 2816.0 82.446516 ... 84.523664 83.712490
21 2944.0 81.967162 ... 83.758038 82.373605
22 3072.0 82.420822 ... 88.750943 86.579673
23 3200.0 81.528664 ... 91.233074 95.665176
24 3328.0 83.516586 ... 85.908470 83.323259
25 3456.0 81.435930 ... 92.138932 90.180725
26 3584.0 83.954614 ... 91.189190 95.858629
27 3712.0 85.822459 ... 83.806497 87.783251
28 3840.0 80.901241 ... 89.259080 89.548180
29 3968.0 87.913500 ... 92.829164 84.096442
30 4096.0 93.825748 ... 89.299883 90.139506
13 1920.0 69.120002 ... 71.257735 71.257735
14 2048.0 73.584279 ... 78.398206 77.314362
15 2176.0 83.155572 ... 87.494120 85.998493
16 2304.0 68.446623 ... 78.320893 77.558029
17 2432.0 71.305746 ... 86.711310 75.421383
18 2560.0 77.833728 ... 82.747477 81.715711
19 2688.0 83.552988 ... 90.532356 89.464755
20 2816.0 84.197315 ... 84.035084 84.035084
21 2944.0 82.784108 ... 83.969728 83.060049
22 3072.0 81.825298 ... 89.593522 88.473602
23 3200.0 84.768213 ... 96.096095 95.808380
24 3328.0 83.226931 ... 85.908470 84.596116
25 3456.0 81.766291 ... 91.824110 91.097818
26 3584.0 87.466332 ... 91.194972 94.847460
27 3712.0 85.822459 ... 87.246590 87.860458
28 3840.0 81.859361 ... 87.011801 90.168771
29 3968.0 89.921841 ... 91.954739 85.271796
30 4096.0 93.596744 ... 88.243079 90.382307
[31 rows x 5 columns]
@@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 52.578 seconds)
**Total running time of the script:** ( 7 minutes 13.827 seconds)
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:

View File

@@ -240,7 +240,7 @@ References
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.476 seconds)
**Total running time of the script:** ( 0 minutes 0.279 seconds)
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:

View File

@@ -21,7 +21,7 @@
Layer Normalization
====================
.. GENERATED FROM PYTHON SOURCE LINES 5-312
.. GENERATED FROM PYTHON SOURCE LINES 5-316
@@ -40,34 +40,34 @@ Layer Normalization
N Triton Torch Apex
0 1024.0 585.142849 277.694907 468.114273
1 1536.0 630.153868 323.368435 511.999982
2 2048.0 682.666643 334.367358 520.126988
3 2560.0 694.237267 365.714281 518.481028
4 3072.0 712.347810 378.092307 501.551037
5 3584.0 725.873439 384.859062 458.751978
6 4096.0 728.177767 381.023256 458.293714
7 4608.0 670.254540 396.387087 426.173427
8 5120.0 694.237267 397.669909 426.666652
9 5632.0 704.000002 396.969169 413.357796
10 6144.0 702.171410 402.885254 411.313806
2 2048.0 668.734716 337.814445 528.516136
3 2560.0 694.237267 362.477870 512.000013
4 3072.0 712.347810 375.206126 501.551037
5 3584.0 725.873439 384.859062 451.527536
6 4096.0 728.177767 381.023256 455.111095
7 4608.0 670.254540 396.387087 421.302872
8 5120.0 688.403381 395.748783 422.268057
9 5632.0 698.542675 396.969169 409.599997
10 6144.0 702.171410 402.885254 409.600010
11 6656.0 700.631610 400.360920 400.360920
12 7168.0 695.078767 396.844306 388.772874
13 7680.0 682.666656 393.846167 387.634072
14 8192.0 642.509816 393.609605 372.363633
15 8704.0 627.315309 389.005597 380.502740
16 9216.0 606.814809 407.337026 383.999986
17 9728.0 589.575753 409.599987 383.369452
18 10240.0 566.920437 408.578556 382.803739
19 10752.0 549.623009 411.559798 381.445676
20 11264.0 536.380957 406.826188 373.134567
21 11776.0 523.377770 410.492372 377.587162
22 12288.0 517.389457 414.784810 383.251457
23 12800.0 505.679014 410.420828 376.470582
24 13312.0 494.180982 405.699062 376.976995
25 13824.0 482.934503 411.888257 379.389355
26 14336.0 471.967074 406.695045 374.185964
27 14848.0 461.297068 408.192434 375.304904
28 15360.0 454.269882 406.214870 378.092307
29 15872.0 447.887117 407.627589 376.225175
12 7168.0 678.627194 386.154893 384.859062
13 7680.0 682.666656 391.337574 386.415087
14 8192.0 645.674867 390.095241 376.643677
15 8704.0 624.502255 390.095225 379.465939
16 9216.0 604.327881 405.098894 383.002605
17 9728.0 585.142883 409.599987 382.427505
18 10240.0 564.965524 409.600010 382.803739
19 10752.0 546.133312 410.577576 380.601764
20 11264.0 531.634232 395.228063 370.069806
21 11776.0 520.486200 409.599991 376.831982
22 12288.0 516.031509 413.911572 383.251457
23 12800.0 504.433489 410.420828 375.779805
24 13312.0 494.180982 405.699062 376.310952
25 13824.0 481.882350 411.888257 378.739711
26 14336.0 471.967074 401.709294 372.969090
27 14848.0 461.297068 407.492270 375.898745
28 15360.0 453.431739 406.887417 378.092307
29 15872.0 447.098578 406.323209 376.225175
@@ -204,17 +204,19 @@ Layer Normalization
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
@@ -287,7 +289,15 @@ Layer Normalization
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
# accumulate partial sums in separate kernel
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
@@ -296,17 +306,11 @@ Layer Normalization
dbias,
M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None, None,
None, None, None, None,
None,
None, None, None,
None,
None, None, None,
None, None, None,
None, None, None)
return (da, None, dweight, dbias, None)
def layer_norm(a, normalized_shape, weight, bias, eps):
@@ -389,7 +393,7 @@ Layer Normalization
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 5 minutes 24.641 seconds)
**Total running time of the script:** ( 5 minutes 32.552 seconds)
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:

View File

@@ -0,0 +1,416 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/06-fused-attention.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_06-fused-attention.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_06-fused-attention.py:
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
.. GENERATED FROM PYTHON SOURCE LINES 7-355
.. code-block:: default
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=64, num_warps=4,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = 64
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.072 seconds)
.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 06-fused-attention.py <06-fused-attention.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 06-fused-attention.ipynb <06-fused-attention.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -0,0 +1,183 @@
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/07-libdevice-function.py"
.. LINE NUMBERS ARE GIVEN BELOW.
.. only:: html
.. note::
:class: sphx-glr-download-link-note
Click :ref:`here <sphx_glr_download_getting-started_tutorials_07-libdevice-function.py>`
to download the full example code
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_getting-started_tutorials_07-libdevice-function.py:
Libdevice function
===============
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
.. GENERATED FROM PYTHON SOURCE LINES 15-17
asin Kernel
--------------------------
.. GENERATED FROM PYTHON SOURCE LINES 17-39
.. code-block:: default
import torch
import triton
import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x = tl.libdevice.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
.. GENERATED FROM PYTHON SOURCE LINES 40-43
Using the default libdevice library path
--------------------------
We can use the default libdevice library path encoded in `triton/language/libdevice.py`
.. GENERATED FROM PYTHON SOURCE LINES 43-61
.. code-block:: default
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
.. GENERATED FROM PYTHON SOURCE LINES 62-65
Customize the libdevice library path
--------------------------
We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
.. GENERATED FROM PYTHON SOURCE LINES 65-75
.. code-block:: default
output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
.. rst-class:: sphx-glr-script-out
Out:
.. code-block:: none
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
The maximum difference between torch and triton is 2.384185791015625e-07
.. rst-class:: sphx-glr-timing
**Total running time of the script:** ( 0 minutes 0.501 seconds)
.. _sphx_glr_download_getting-started_tutorials_07-libdevice-function.py:
.. only :: html
.. container:: sphx-glr-footer
:class: sphx-glr-footer-example
.. container:: sphx-glr-download sphx-glr-download-python
:download:`Download Python source code: 07-libdevice-function.py <07-libdevice-function.py>`
.. container:: sphx-glr-download sphx-glr-download-jupyter
:download:`Download Jupyter notebook: 07-libdevice-function.ipynb <07-libdevice-function.ipynb>`
.. only:: html
.. rst-class:: sphx-glr-signature
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_

View File

@@ -122,6 +122,48 @@ To install the dependencies for the tutorials:
:hidden:
/getting-started/tutorials/05-layer-norm
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_06-fused-attention_thumb.png
:alt: Fused Attention
:ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/06-fused-attention
.. raw:: html
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d...">
.. only:: html
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_07-libdevice-function_thumb.png
:alt: Libdevice function
:ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py`
.. raw:: html
</div>
.. toctree::
:hidden:
/getting-started/tutorials/07-libdevice-function
.. raw:: html
<div class="sphx-glr-clear"></div>

View File

@@ -5,16 +5,20 @@
Computation times
=================
**16:10.599** total execution time for **getting-started_tutorials** files:
**18:09.339** total execution time for **getting-started_tutorials** files:
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:52.578 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:13.827 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:24.641 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:32.552 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:18.076 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:32.089 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:34.829 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.020 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.476 | 0.0 MB |
| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.501 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+
| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.072 | 0.0 MB |
+---------------------------------------------------------------------------------------------------------+-----------+--------+