[GH-PAGES] Updated website
This commit is contained in:
@@ -238,7 +238,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
3 32768.0 76.800002 76.800002
|
||||
4 65536.0 127.999995 127.999995
|
||||
5 131072.0 219.428568 219.428568
|
||||
6 262144.0 341.333321 384.000001
|
||||
6 262144.0 341.333321 341.333321
|
||||
7 524288.0 472.615390 472.615390
|
||||
8 1048576.0 614.400016 614.400016
|
||||
9 2097152.0 722.823517 722.823517
|
||||
@@ -255,7 +255,7 @@ We can now run the decorated function above. Pass `print_data=True` to see the p
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 1 minutes 34.829 seconds)
|
||||
**Total running time of the script:** ( 1 minutes 50.020 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_01-vector-add.py:
|
||||
|
@@ -278,17 +278,17 @@ We will then compare its performance against (1) :code:`torch.softmax` and (2) t
|
||||
|
||||
softmax-performance:
|
||||
N Triton Torch (native) Torch (jit)
|
||||
0 256.0 546.133347 512.000001 188.321838
|
||||
1 384.0 614.400016 585.142862 153.600004
|
||||
2 512.0 655.360017 585.142849 154.566038
|
||||
0 256.0 546.133347 546.133347 186.181817
|
||||
1 384.0 614.400016 585.142862 151.703707
|
||||
2 512.0 655.360017 606.814814 154.566038
|
||||
3 640.0 706.206879 640.000002 160.000000
|
||||
4 768.0 722.823517 664.216187 162.754967
|
||||
4 768.0 722.823517 664.216187 163.839992
|
||||
.. ... ... ... ...
|
||||
93 12160.0 812.359066 406.179533 198.936606
|
||||
94 12288.0 812.429770 415.222812 199.298541
|
||||
95 12416.0 812.498981 412.149375 198.954424
|
||||
96 12544.0 812.566838 412.758863 199.209928
|
||||
97 12672.0 811.007961 412.097543 199.264875
|
||||
94 12288.0 812.429770 415.222812 199.096718
|
||||
95 12416.0 812.498981 412.149375 198.854847
|
||||
96 12544.0 810.925276 412.971190 199.012395
|
||||
97 12672.0 811.007961 412.097543 199.167004
|
||||
|
||||
[98 rows x 4 columns]
|
||||
|
||||
@@ -306,7 +306,7 @@ In the above plot, we can see that:
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 3 minutes 18.076 seconds)
|
||||
**Total running time of the script:** ( 3 minutes 32.089 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:
|
||||
|
@@ -459,37 +459,37 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
|
||||
matmul-performance:
|
||||
M cuBLAS ... Triton Triton (+ LeakyReLU)
|
||||
0 256.0 2.730667 ... 2.978909 2.978909
|
||||
0 256.0 2.730667 ... 3.276800 2.978909
|
||||
1 384.0 7.372800 ... 7.899428 7.899428
|
||||
2 512.0 14.563555 ... 15.420235 15.420235
|
||||
2 512.0 14.563555 ... 16.384000 15.420235
|
||||
3 640.0 22.260869 ... 24.380953 24.380953
|
||||
4 768.0 32.768000 ... 35.389441 34.028308
|
||||
5 896.0 37.971025 ... 40.140799 39.025776
|
||||
5 896.0 39.025776 ... 40.140799 39.025776
|
||||
6 1024.0 49.932191 ... 53.773130 52.428801
|
||||
7 1152.0 45.242181 ... 48.161033 47.396572
|
||||
7 1152.0 45.242181 ... 47.396572 47.396572
|
||||
8 1280.0 51.200001 ... 57.690139 57.690139
|
||||
9 1408.0 64.138541 ... 68.147202 65.684049
|
||||
10 1536.0 79.526831 ... 81.355034 78.643199
|
||||
11 1664.0 63.372618 ... 63.372618 62.492442
|
||||
9 1408.0 64.138541 ... 68.147202 66.485074
|
||||
10 1536.0 80.430545 ... 80.430545 78.643199
|
||||
11 1664.0 62.929456 ... 63.372618 62.492442
|
||||
12 1792.0 72.983276 ... 72.983276 59.154861
|
||||
13 1920.0 68.776119 ... 71.626943 70.892307
|
||||
14 2048.0 73.584279 ... 78.033565 76.959706
|
||||
15 2176.0 83.155572 ... 87.494120 86.367588
|
||||
16 2304.0 68.446623 ... 78.064941 77.057651
|
||||
17 2432.0 71.305746 ... 86.179335 85.393507
|
||||
18 2560.0 77.833728 ... 82.956960 81.715711
|
||||
19 2688.0 83.737433 ... 91.185232 89.464755
|
||||
20 2816.0 82.446516 ... 84.523664 83.712490
|
||||
21 2944.0 81.967162 ... 83.758038 82.373605
|
||||
22 3072.0 82.420822 ... 88.750943 86.579673
|
||||
23 3200.0 81.528664 ... 91.233074 95.665176
|
||||
24 3328.0 83.516586 ... 85.908470 83.323259
|
||||
25 3456.0 81.435930 ... 92.138932 90.180725
|
||||
26 3584.0 83.954614 ... 91.189190 95.858629
|
||||
27 3712.0 85.822459 ... 83.806497 87.783251
|
||||
28 3840.0 80.901241 ... 89.259080 89.548180
|
||||
29 3968.0 87.913500 ... 92.829164 84.096442
|
||||
30 4096.0 93.825748 ... 89.299883 90.139506
|
||||
13 1920.0 69.120002 ... 71.257735 71.257735
|
||||
14 2048.0 73.584279 ... 78.398206 77.314362
|
||||
15 2176.0 83.155572 ... 87.494120 85.998493
|
||||
16 2304.0 68.446623 ... 78.320893 77.558029
|
||||
17 2432.0 71.305746 ... 86.711310 75.421383
|
||||
18 2560.0 77.833728 ... 82.747477 81.715711
|
||||
19 2688.0 83.552988 ... 90.532356 89.464755
|
||||
20 2816.0 84.197315 ... 84.035084 84.035084
|
||||
21 2944.0 82.784108 ... 83.969728 83.060049
|
||||
22 3072.0 81.825298 ... 89.593522 88.473602
|
||||
23 3200.0 84.768213 ... 96.096095 95.808380
|
||||
24 3328.0 83.226931 ... 85.908470 84.596116
|
||||
25 3456.0 81.766291 ... 91.824110 91.097818
|
||||
26 3584.0 87.466332 ... 91.194972 94.847460
|
||||
27 3712.0 85.822459 ... 87.246590 87.860458
|
||||
28 3840.0 81.859361 ... 87.011801 90.168771
|
||||
29 3968.0 89.921841 ... 91.954739 85.271796
|
||||
30 4096.0 93.596744 ... 88.243079 90.382307
|
||||
|
||||
[31 rows x 5 columns]
|
||||
|
||||
@@ -499,7 +499,7 @@ We can now compare the performance of our kernel against that of cuBLAS. Here we
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 5 minutes 52.578 seconds)
|
||||
**Total running time of the script:** ( 7 minutes 13.827 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_03-matrix-multiplication.py:
|
||||
|
@@ -240,7 +240,7 @@ References
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.476 seconds)
|
||||
**Total running time of the script:** ( 0 minutes 0.279 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:
|
||||
|
@@ -21,7 +21,7 @@
|
||||
Layer Normalization
|
||||
====================
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 5-312
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 5-316
|
||||
|
||||
|
||||
|
||||
@@ -40,34 +40,34 @@ Layer Normalization
|
||||
N Triton Torch Apex
|
||||
0 1024.0 585.142849 277.694907 468.114273
|
||||
1 1536.0 630.153868 323.368435 511.999982
|
||||
2 2048.0 682.666643 334.367358 520.126988
|
||||
3 2560.0 694.237267 365.714281 518.481028
|
||||
4 3072.0 712.347810 378.092307 501.551037
|
||||
5 3584.0 725.873439 384.859062 458.751978
|
||||
6 4096.0 728.177767 381.023256 458.293714
|
||||
7 4608.0 670.254540 396.387087 426.173427
|
||||
8 5120.0 694.237267 397.669909 426.666652
|
||||
9 5632.0 704.000002 396.969169 413.357796
|
||||
10 6144.0 702.171410 402.885254 411.313806
|
||||
2 2048.0 668.734716 337.814445 528.516136
|
||||
3 2560.0 694.237267 362.477870 512.000013
|
||||
4 3072.0 712.347810 375.206126 501.551037
|
||||
5 3584.0 725.873439 384.859062 451.527536
|
||||
6 4096.0 728.177767 381.023256 455.111095
|
||||
7 4608.0 670.254540 396.387087 421.302872
|
||||
8 5120.0 688.403381 395.748783 422.268057
|
||||
9 5632.0 698.542675 396.969169 409.599997
|
||||
10 6144.0 702.171410 402.885254 409.600010
|
||||
11 6656.0 700.631610 400.360920 400.360920
|
||||
12 7168.0 695.078767 396.844306 388.772874
|
||||
13 7680.0 682.666656 393.846167 387.634072
|
||||
14 8192.0 642.509816 393.609605 372.363633
|
||||
15 8704.0 627.315309 389.005597 380.502740
|
||||
16 9216.0 606.814809 407.337026 383.999986
|
||||
17 9728.0 589.575753 409.599987 383.369452
|
||||
18 10240.0 566.920437 408.578556 382.803739
|
||||
19 10752.0 549.623009 411.559798 381.445676
|
||||
20 11264.0 536.380957 406.826188 373.134567
|
||||
21 11776.0 523.377770 410.492372 377.587162
|
||||
22 12288.0 517.389457 414.784810 383.251457
|
||||
23 12800.0 505.679014 410.420828 376.470582
|
||||
24 13312.0 494.180982 405.699062 376.976995
|
||||
25 13824.0 482.934503 411.888257 379.389355
|
||||
26 14336.0 471.967074 406.695045 374.185964
|
||||
27 14848.0 461.297068 408.192434 375.304904
|
||||
28 15360.0 454.269882 406.214870 378.092307
|
||||
29 15872.0 447.887117 407.627589 376.225175
|
||||
12 7168.0 678.627194 386.154893 384.859062
|
||||
13 7680.0 682.666656 391.337574 386.415087
|
||||
14 8192.0 645.674867 390.095241 376.643677
|
||||
15 8704.0 624.502255 390.095225 379.465939
|
||||
16 9216.0 604.327881 405.098894 383.002605
|
||||
17 9728.0 585.142883 409.599987 382.427505
|
||||
18 10240.0 564.965524 409.600010 382.803739
|
||||
19 10752.0 546.133312 410.577576 380.601764
|
||||
20 11264.0 531.634232 395.228063 370.069806
|
||||
21 11776.0 520.486200 409.599991 376.831982
|
||||
22 12288.0 516.031509 413.911572 383.251457
|
||||
23 12800.0 504.433489 410.420828 375.779805
|
||||
24 13312.0 494.180982 405.699062 376.310952
|
||||
25 13824.0 481.882350 411.888257 378.739711
|
||||
26 14336.0 471.967074 401.709294 372.969090
|
||||
27 14848.0 461.297068 407.492270 375.898745
|
||||
28 15360.0 453.431739 406.887417 378.092307
|
||||
29 15872.0 447.098578 406.323209 376.225175
|
||||
|
||||
|
||||
|
||||
@@ -204,17 +204,19 @@ Layer Normalization
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
UNROLL: tl.constexpr = 4
|
||||
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
|
||||
for j in range(UNROLL):
|
||||
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
tl.store(DW + cols, sum_dw, mask=cols < N)
|
||||
@@ -287,7 +289,15 @@ Layer Normalization
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
)
|
||||
# accumulate partial sums in separate kernel
|
||||
if N > 10240:
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_M = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
BLOCK_SIZE_N = 16
|
||||
BLOCK_SIZE_M = 16
|
||||
num_warps = 8
|
||||
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||
_layer_norm_bwd_dwdb[grid](
|
||||
a, dout,
|
||||
@@ -296,17 +306,11 @@ Layer Normalization
|
||||
dbias,
|
||||
M,
|
||||
N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
num_warps=num_warps
|
||||
)
|
||||
return (da, None, dweight, dbias, None, None,
|
||||
None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None, None, None,
|
||||
None, None, None)
|
||||
return (da, None, dweight, dbias, None)
|
||||
|
||||
|
||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||
@@ -389,7 +393,7 @@ Layer Normalization
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 5 minutes 24.641 seconds)
|
||||
**Total running time of the script:** ( 5 minutes 32.552 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py:
|
||||
|
@@ -0,0 +1,416 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/06-fused-attention.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_06-fused-attention.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_06-fused-attention.py:
|
||||
|
||||
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 7-355
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
denom = tl.load(L + off_m).to(tl.float32)
|
||||
# compute
|
||||
do = do / denom[:, None]
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_qz + off_h * stride_qh
|
||||
V += off_z * stride_qz + off_h * stride_qh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
lo = start_n * BLOCK_M
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
m_ptrs = M + off_hz * N_CTX
|
||||
# initialize dv amd dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k.shape[-1]
|
||||
assert Lq == Lk
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=64, num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = 64
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# only works on A100 at the moment
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.072 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 06-fused-attention.py <06-fused-attention.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 06-fused-attention.ipynb <06-fused-attention.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -0,0 +1,183 @@
|
||||
|
||||
.. DO NOT EDIT.
|
||||
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
|
||||
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
|
||||
.. "getting-started/tutorials/07-libdevice-function.py"
|
||||
.. LINE NUMBERS ARE GIVEN BELOW.
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. note::
|
||||
:class: sphx-glr-download-link-note
|
||||
|
||||
Click :ref:`here <sphx_glr_download_getting-started_tutorials_07-libdevice-function.py>`
|
||||
to download the full example code
|
||||
|
||||
.. rst-class:: sphx-glr-example-title
|
||||
|
||||
.. _sphx_glr_getting-started_tutorials_07-libdevice-function.py:
|
||||
|
||||
|
||||
Libdevice function
|
||||
===============
|
||||
Triton can invoke a custom function from an external library.
|
||||
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
|
||||
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
|
||||
|
||||
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
|
||||
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
|
||||
Using triton, you can simply call `tl.libdevice.asinf`.
|
||||
triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 15-17
|
||||
|
||||
asin Kernel
|
||||
--------------------------
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 17-39
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def asin_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
x = tl.libdevice.asin(x)
|
||||
tl.store(y_ptr + offsets, x, mask=mask)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 40-43
|
||||
|
||||
Using the default libdevice library path
|
||||
--------------------------
|
||||
We can use the default libdevice library path encoded in `triton/language/libdevice.py`
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 43-61
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
output_triton = torch.zeros(size, device='cuda')
|
||||
output_torch = torch.asin(x)
|
||||
assert x.is_cuda and output_triton.is_cuda
|
||||
n_elements = output_torch.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
|
||||
|
||||
|
||||
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 62-65
|
||||
|
||||
Customize the libdevice library path
|
||||
--------------------------
|
||||
We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
|
||||
|
||||
.. GENERATED FROM PYTHON SOURCE LINES 65-75
|
||||
|
||||
.. code-block:: default
|
||||
|
||||
|
||||
output_triton = torch.empty_like(x)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-script-out
|
||||
|
||||
Out:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
tensor([0.4105, 0.5430, 0.0249, ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
|
||||
The maximum difference between torch and triton is 2.384185791015625e-07
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
.. rst-class:: sphx-glr-timing
|
||||
|
||||
**Total running time of the script:** ( 0 minutes 0.501 seconds)
|
||||
|
||||
|
||||
.. _sphx_glr_download_getting-started_tutorials_07-libdevice-function.py:
|
||||
|
||||
|
||||
.. only :: html
|
||||
|
||||
.. container:: sphx-glr-footer
|
||||
:class: sphx-glr-footer-example
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-python
|
||||
|
||||
:download:`Download Python source code: 07-libdevice-function.py <07-libdevice-function.py>`
|
||||
|
||||
|
||||
|
||||
.. container:: sphx-glr-download sphx-glr-download-jupyter
|
||||
|
||||
:download:`Download Jupyter notebook: 07-libdevice-function.ipynb <07-libdevice-function.ipynb>`
|
||||
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. rst-class:: sphx-glr-signature
|
||||
|
||||
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
|
@@ -122,6 +122,48 @@ To install the dependencies for the tutorials:
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/05-layer-norm
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="Fused Attention">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_06-fused-attention_thumb.png
|
||||
:alt: Fused Attention
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/06-fused-attention
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-thumbcontainer" tooltip="In trition/language/libdevice.py, we try to aggregate functions with the same computation but d...">
|
||||
|
||||
.. only:: html
|
||||
|
||||
.. figure:: /getting-started/tutorials/images/thumb/sphx_glr_07-libdevice-function_thumb.png
|
||||
:alt: Libdevice function
|
||||
|
||||
:ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py`
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
/getting-started/tutorials/07-libdevice-function
|
||||
.. raw:: html
|
||||
|
||||
<div class="sphx-glr-clear"></div>
|
||||
|
@@ -5,16 +5,20 @@
|
||||
|
||||
Computation times
|
||||
=================
|
||||
**16:10.599** total execution time for **getting-started_tutorials** files:
|
||||
**18:09.339** total execution time for **getting-started_tutorials** files:
|
||||
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 05:52.578 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_03-matrix-multiplication.py` (``03-matrix-multiplication.py``) | 07:13.827 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:24.641 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_05-layer-norm.py` (``05-layer-norm.py``) | 05:32.552 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:18.076 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_02-fused-softmax.py` (``02-fused-softmax.py``) | 03:32.089 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:34.829 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_01-vector-add.py` (``01-vector-add.py``) | 01:50.020 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.476 | 0.0 MB |
|
||||
| :ref:`sphx_glr_getting-started_tutorials_07-libdevice-function.py` (``07-libdevice-function.py``) | 00:00.501 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_04-low-memory-dropout.py` (``04-low-memory-dropout.py``) | 00:00.279 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
| :ref:`sphx_glr_getting-started_tutorials_06-fused-attention.py` (``06-fused-attention.py``) | 00:00.072 | 0.0 MB |
|
||||
+---------------------------------------------------------------------------------------------------------+-----------+--------+
|
||||
|
Reference in New Issue
Block a user