[GH-PAGES] Updated website
@@ -21,7 +21,7 @@
 
 Layer Normalization
 ====================
 
-.. GENERATED FROM PYTHON SOURCE LINES 5-312
+.. GENERATED FROM PYTHON SOURCE LINES 5-316
 
 
@@ -40,34 +40,34 @@ Layer Normalization
 
           N      Triton       Torch        Apex
  0   1024.0  585.142849  277.694907  468.114273
  1   1536.0  630.153868  323.368435  511.999982
- 2   2048.0  682.666643  334.367358  520.126988
- 3   2560.0  694.237267  365.714281  518.481028
- 4   3072.0  712.347810  378.092307  501.551037
- 5   3584.0  725.873439  384.859062  458.751978
- 6   4096.0  728.177767  381.023256  458.293714
- 7   4608.0  670.254540  396.387087  426.173427
- 8   5120.0  694.237267  397.669909  426.666652
- 9   5632.0  704.000002  396.969169  413.357796
-10   6144.0  702.171410  402.885254  411.313806
+ 2   2048.0  668.734716  337.814445  528.516136
+ 3   2560.0  694.237267  362.477870  512.000013
+ 4   3072.0  712.347810  375.206126  501.551037
+ 5   3584.0  725.873439  384.859062  451.527536
+ 6   4096.0  728.177767  381.023256  455.111095
+ 7   4608.0  670.254540  396.387087  421.302872
+ 8   5120.0  688.403381  395.748783  422.268057
+ 9   5632.0  698.542675  396.969169  409.599997
+10   6144.0  702.171410  402.885254  409.600010
 11   6656.0  700.631610  400.360920  400.360920
-12   7168.0  695.078767  396.844306  388.772874
-13   7680.0  682.666656  393.846167  387.634072
-14   8192.0  642.509816  393.609605  372.363633
-15   8704.0  627.315309  389.005597  380.502740
-16   9216.0  606.814809  407.337026  383.999986
-17   9728.0  589.575753  409.599987  383.369452
-18  10240.0  566.920437  408.578556  382.803739
-19  10752.0  549.623009  411.559798  381.445676
-20  11264.0  536.380957  406.826188  373.134567
-21  11776.0  523.377770  410.492372  377.587162
-22  12288.0  517.389457  414.784810  383.251457
-23  12800.0  505.679014  410.420828  376.470582
-24  13312.0  494.180982  405.699062  376.976995
-25  13824.0  482.934503  411.888257  379.389355
-26  14336.0  471.967074  406.695045  374.185964
-27  14848.0  461.297068  408.192434  375.304904
-28  15360.0  454.269882  406.214870  378.092307
-29  15872.0  447.887117  407.627589  376.225175
+12   7168.0  678.627194  386.154893  384.859062
+13   7680.0  682.666656  391.337574  386.415087
+14   8192.0  645.674867  390.095241  376.643677
+15   8704.0  624.502255  390.095225  379.465939
+16   9216.0  604.327881  405.098894  383.002605
+17   9728.0  585.142883  409.599987  382.427505
+18  10240.0  564.965524  409.600010  382.803739
+19  10752.0  546.133312  410.577576  380.601764
+20  11264.0  531.634232  395.228063  370.069806
+21  11776.0  520.486200  409.599991  376.831982
+22  12288.0  516.031509  413.911572  383.251457
+23  12800.0  504.433489  410.420828  375.779805
+24  13312.0  494.180982  405.699062  376.310952
+25  13824.0  481.882350  411.888257  378.739711
+26  14336.0  471.967074  401.709294  372.969090
+27  14848.0  461.297068  407.492270  375.898745
+28  15360.0  453.431739  406.887417  378.092307
+29  15872.0  447.098578  406.323209  376.225175
 
 
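The block above is the re-generated benchmark table of the tutorial: throughput (presumably GB/s, higher is better, as in the other Triton tutorial benchmarks) for the Triton, Torch and Apex layer-norm implementations at increasing feature size N, with the old measurements removed and the re-run numbers added. As a rough, hedged sketch of how such a table can be produced (this is not the tutorial's exact harness; the time_ms helper and the traffic model in the GB/s conversion are assumptions), one can time each implementation with CUDA events and convert milliseconds to effective bandwidth:

    import torch
    import torch.nn.functional as F

    def time_ms(fn, warmup=10, rep=50):
        # Hypothetical timing helper: average wall-clock time of a CUDA callable in ms.
        for _ in range(warmup):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(rep):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / rep

    def bench_torch_layer_norm(N, M=4096, dtype=torch.float16):
        x = torch.randn(M, N, device='cuda', dtype=dtype)
        w = torch.randn(N, device='cuda', dtype=dtype)
        b = torch.randn(N, device='cuda', dtype=dtype)
        ms = time_ms(lambda: F.layer_norm(x, (N,), w, b, 1e-5))
        # Assumed traffic model: read x once and write y once (2 * numel * bytes).
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    for N in (1024, 1536, 2048, 2560):
        print(f"{N:6d}  {bench_torch_layer_norm(N):10.2f} GB/s")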
@@ -204,17 +204,19 @@ Layer Normalization
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, BLOCK_SIZE_M)
-        mask = (rows[:, None] < M) & (cols[None, :] < N)
-        offs = rows[:, None] * N + cols[None, :]
-        a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
-        dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
-        mean = tl.load(Mean + rows, mask=rows < M, other=0.)
-        rstd = tl.load(Var + rows, mask=rows < M, other=0.)
-        a_hat = (a - mean[:, None]) * rstd[:, None]
-        dw += dout * a_hat
-        db += dout
+    UNROLL: tl.constexpr = 4
+    for i in range(0, M, BLOCK_SIZE_M * UNROLL):
+        for j in range(UNROLL):
+            rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            mask = (rows[:, None] < M) & (cols[None, :] < N)
+            offs = rows[:, None] * N + cols[None, :]
+            a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
+            dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
+            mean = tl.load(Mean + rows, mask=rows < M, other=0.)
+            rstd = tl.load(Var + rows, mask=rows < M, other=0.)
+            a_hat = (a - mean[:, None]) * rstd[:, None]
+            dw += dout * a_hat
+            db += dout
     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)
     tl.store(DW + cols, sum_dw, mask=cols < N)
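Both loop variants accumulate the same per-column partial sums; the new code only unrolls the row loop by a constexpr factor of 4 so that each program issues more independent loads per outer iteration, and the final tl.sum over axis 0 still collapses the (BLOCK_SIZE_M, BLOCK_SIZE_N) accumulators into one value per column. For reference, a hedged PyTorch sketch of what this kernel computes (assuming, as the kernel's own variable names suggest, that Mean holds the saved per-row means and Var the saved reciprocal standard deviations):

    import torch

    def dwdb_reference(a, dout, mean, rstd):
        # a, dout: (M, N); mean, rstd: (M,) saved by the forward pass.
        a_hat = (a - mean[:, None]) * rstd[:, None]   # re-normalized activations
        dweight = (dout * a_hat).sum(dim=0)           # column-wise sum, i.e. sum_dw
        dbias = dout.sum(dim=0)                       # column-wise sum, i.e. sum_db
        return dweight, dbias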
@@ -287,7 +289,15 @@ Layer Normalization
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,
             num_warps=ctx.num_warps,
         )
         # accumulate partial sums in separate kernel
+        if N > 10240:
+            BLOCK_SIZE_N = 128
+            BLOCK_SIZE_M = 32
+            num_warps = 4
+        else:
+            # maximize occupancy for small N
+            BLOCK_SIZE_N = 16
+            BLOCK_SIZE_M = 16
+            num_warps = 8
         grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
         _layer_norm_bwd_dwdb[grid](
             a, dout,
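The added branch replaces the previously hard-coded BLOCK_SIZE_M=32 / BLOCK_SIZE_N=128 launch configuration of _layer_norm_bwd_dwdb: for large N the wide column tile is kept, while for small N the diff's own comment states the intent is to maximize occupancy, so it switches to smaller 16x16 tiles and more warps. A minimal sketch of the same heuristic factored into a helper (the function name is hypothetical; the constants are taken from the diff):

    def pick_dwdb_config(N):
        # Constants mirror the heuristic introduced in this commit.
        if N > 10240:
            return dict(BLOCK_SIZE_N=128, BLOCK_SIZE_M=32, num_warps=4)
        # maximize occupancy for small N: smaller tiles, more programs, more warps
        return dict(BLOCK_SIZE_N=16, BLOCK_SIZE_M=16, num_warps=8)

Since grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] launches one program per BLOCK_SIZE_N-wide column tile, shrinking BLOCK_SIZE_N from 128 to 16 multiplies the number of programs by 8, which is what keeps the GPU busy when rows are short.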
@@ -296,17 +306,11 @@ Layer Normalization
             dbias,
             M,
             N,
-            BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        return (da, None, dweight, dbias, None, None,
-                None, None, None, None,
-                None,
-                None, None, None,
-                None,
-                None, None, None,
-                None, None, None,
-                None, None, None)
+        return (da, None, dweight, dbias, None)
 
 
 def layer_norm(a, normalized_shape, weight, bias, eps):
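The long tuple of None entries is trimmed because backward() of a torch.autograd.Function must return one gradient slot per argument of forward(); the five-element tuple suggests forward() now takes exactly the five arguments exposed by the layer_norm wrapper below. A hedged usage sketch that checks the wrapper against PyTorch's built-in layer norm (shapes and tolerances are illustrative, not taken from this commit):

    import torch
    import torch.nn.functional as F

    M, N = 1151, 8192
    x = torch.randn(M, N, device='cuda', dtype=torch.float16, requires_grad=True)
    w = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)
    b = torch.rand(N, device='cuda', dtype=torch.float16, requires_grad=True)

    y_tri = layer_norm(x, (N,), w, b, 1e-5)    # wrapper defined in this tutorial
    y_ref = F.layer_norm(x, (N,), w, b, 1e-5)  # PyTorch reference
    print(torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0))

    # Gradients flow through the trimmed backward(): da, dweight, dbias plus
    # None for the non-tensor arguments (normalized_shape and eps).
    y_tri.backward(torch.randn_like(y_tri))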
@@ -389,7 +393,7 @@ Layer Normalization
 
 
 .. rst-class:: sphx-glr-timing
 
-   **Total running time of the script:** ( 5 minutes 24.641 seconds)
+   **Total running time of the script:** ( 5 minutes 32.552 seconds)
 
 
 .. _sphx_glr_download_getting-started_tutorials_05-layer-norm.py: