[GH-PAGES] Updated website

Philippe Tillet
2022-07-14 07:22:19 +00:00
parent 3e815114fd
commit d1c6625bfd
179 changed files with 2617 additions and 369 deletions


@@ -21,7 +21,7 @@
 Layer Normalization
 ====================
-.. GENERATED FROM PYTHON SOURCE LINES 5-312
+.. GENERATED FROM PYTHON SOURCE LINES 5-316
@@ -40,34 +40,34 @@ Layer Normalization
            N      Triton       Torch        Apex
  0   1024.0  585.142849  277.694907  468.114273
  1   1536.0  630.153868  323.368435  511.999982
- 2   2048.0  682.666643  334.367358  520.126988
- 3   2560.0  694.237267  365.714281  518.481028
- 4   3072.0  712.347810  378.092307  501.551037
- 5   3584.0  725.873439  384.859062  458.751978
- 6   4096.0  728.177767  381.023256  458.293714
- 7   4608.0  670.254540  396.387087  426.173427
- 8   5120.0  694.237267  397.669909  426.666652
- 9   5632.0  704.000002  396.969169  413.357796
-10   6144.0  702.171410  402.885254  411.313806
+ 2   2048.0  668.734716  337.814445  528.516136
+ 3   2560.0  694.237267  362.477870  512.000013
+ 4   3072.0  712.347810  375.206126  501.551037
+ 5   3584.0  725.873439  384.859062  451.527536
+ 6   4096.0  728.177767  381.023256  455.111095
+ 7   4608.0  670.254540  396.387087  421.302872
+ 8   5120.0  688.403381  395.748783  422.268057
+ 9   5632.0  698.542675  396.969169  409.599997
+10   6144.0  702.171410  402.885254  409.600010
 11   6656.0  700.631610  400.360920  400.360920
-12   7168.0  695.078767  396.844306  388.772874
-13   7680.0  682.666656  393.846167  387.634072
-14   8192.0  642.509816  393.609605  372.363633
-15   8704.0  627.315309  389.005597  380.502740
-16   9216.0  606.814809  407.337026  383.999986
-17   9728.0  589.575753  409.599987  383.369452
-18  10240.0  566.920437  408.578556  382.803739
-19  10752.0  549.623009  411.559798  381.445676
-20  11264.0  536.380957  406.826188  373.134567
-21  11776.0  523.377770  410.492372  377.587162
-22  12288.0  517.389457  414.784810  383.251457
-23  12800.0  505.679014  410.420828  376.470582
-24  13312.0  494.180982  405.699062  376.976995
-25  13824.0  482.934503  411.888257  379.389355
-26  14336.0  471.967074  406.695045  374.185964
-27  14848.0  461.297068  408.192434  375.304904
-28  15360.0  454.269882  406.214870  378.092307
-29  15872.0  447.887117  407.627589  376.225175
+12   7168.0  678.627194  386.154893  384.859062
+13   7680.0  682.666656  391.337574  386.415087
+14   8192.0  645.674867  390.095241  376.643677
+15   8704.0  624.502255  390.095225  379.465939
+16   9216.0  604.327881  405.098894  383.002605
+17   9728.0  585.142883  409.599987  382.427505
+18  10240.0  564.965524  409.600010  382.803739
+19  10752.0  546.133312  410.577576  380.601764
+20  11264.0  531.634232  395.228063  370.069806
+21  11776.0  520.486200  409.599991  376.831982
+22  12288.0  516.031509  413.911572  383.251457
+23  12800.0  504.433489  410.420828  375.779805
+24  13312.0  494.180982  405.699062  376.310952
+25  13824.0  481.882350  411.888257  378.739711
+26  14336.0  471.967074  401.709294  372.969090
+27  14848.0  461.297068  407.492270  375.898745
+28  15360.0  453.431739  406.887417  378.092307
+29  15872.0  447.098578  406.323209  376.225175
@@ -204,17 +204,19 @@ Layer Normalization
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None, :] < N)
-        offs = rows[:, None] * N + cols[None, :]
-        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
-        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
-        a_hat = (a - mean[:, None]) * rstd[:, None]
-        dw += dout * a_hat
-        db += dout
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            mask = (rows[:, None] < M) & (cols[None, :] < N)
+            offs = rows[:, None] * N + cols[None, :]
+            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+            a_hat = (a - mean[:, None]) * rstd[:, None]
+            dw += dout * a_hat
+            db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
     tl.store(DW + cols, sum_dw, mask=cols < N)
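The hunk above manually unrolls the row loop of the dw/db accumulation: because UNROLL is declared as a tl.constexpr, the inner `for j in range(UNROLL)` loop has a compile-time trip count and is flattened during compilation, so each outer iteration issues four independent masked loads whose memory latency can overlap, instead of one load per accumulator update. Below is a self-contained sketch of the same pattern applied to a plain column-sum reduction; the names `_col_sum` and `col_sum` are illustrative, not part of the tutorial.

import torch
import triton
import triton.language as tl


@triton.jit
def _col_sum(X, Out, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program reduces one BLOCK_N-wide strip of columns over all M rows.
    pid = tl.program_id(0)
    cols = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    UNROLL: tl.constexpr = 4
    # The j-loop has a constexpr trip count, so it is unrolled at compile
    # time: the four masked loads per outer step are independent and can
    # be in flight simultaneously.
    for i in range(0, M, BLOCK_M * UNROLL):
        for j in range(UNROLL):
            rows = i + j * BLOCK_M + tl.arange(0, BLOCK_M)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            x = tl.load(X + rows[:, None] * N + cols[None, :], mask=mask, other=0.)
            acc += x.to(tl.float32)
    tl.store(Out + cols, tl.sum(acc, axis=0), mask=cols < N)


def col_sum(x):
    M, N = x.shape
    out = torch.empty(N, device=x.device, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
    _col_sum[grid](x, out, M, N, BLOCK_M=32, BLOCK_N=128)
    return out

The unrolled form visits exactly the same set of rows as the original loop (out-of-range rows are masked), so for a contiguous 2-D tensor on the GPU, torch.allclose(col_sum(x), x.float().sum(0)) should hold.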
@@ -287,7 +289,15 @@ Layer Normalization
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
-        # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
@@ -296,17 +306,11 @@ Layer Normalization
             dbias,
             M,
             N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)
 def layer_norm(a, normalized_shape, weight, bias, eps):
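The new launch logic picks the `_layer_norm_bwd_dwdb` configuration from the row width N. The grid launches cdiv(N, BLOCK_SIZE_N) programs, so for small N a 128-wide block would leave most of the GPU idle; dropping to 16x16 tiles with more warps per program keeps occupancy up, while wide rows can afford the larger 32x128 tiles. A minimal restatement as a standalone helper (the name `pick_dwdb_config` is hypothetical, not tutorial API):

def pick_dwdb_config(N):
    # Hypothetical helper mirroring the branch above. The grid has
    # cdiv(N, BLOCK_SIZE_N) programs, so a small N needs a small
    # BLOCK_SIZE_N (more programs) plus extra warps to fill the GPU.
    if N > 10240:
        return dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_warps=4)
    return dict(BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, num_warps=8)

The launch would then read `_layer_norm_bwd_dwdb[grid](..., **pick_dwdb_config(N))`; num_warps is a launch option consumed by Triton rather than a kernel argument, which is why it travels alongside the two block sizes.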
@@ -389,7 +393,7 @@ Layer Normalization
 .. rst-class:: sphx-glr-timing
-**Total running time of the script:** ( 5 minutes 24.641 seconds)
+**Total running time of the script:** ( 5 minutes 32.552 seconds)
 .. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py: