[GH-PAGES] Updated website

Author: Philippe Tillet
Date: 2022-08-17 00:49:36 +00:00
parent d1343b5511
commit f20cbb2743
167 changed files with 326 additions and 311 deletions

@@ -23,7 +23,7 @@ Fused Attention
 This is a Triton implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
-.. GENERATED FROM PYTHON SOURCE LINES 7-355
+.. GENERATED FROM PYTHON SOURCE LINES 7-360
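
The hunks below generalize the forward pass: instead of a hard-coded head dimension of 64, the head size is now read off the input tensors, checked against the supported set {16, 32, 64, 128}, and used to choose the warp count. A minimal standalone sketch of that selection logic (the helper name pick_launch_params is invented for illustration):

    import torch

    def pick_launch_params(q, k, v):
        # The head dimension comes from the inputs instead of being fixed at 64.
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv, "q, k, v must share one head dimension"
        assert Lk in {16, 32, 64, 128}, "unsupported head dimension"
        # Smaller head dimensions run with 4 warps; 128 gets 8.
        num_warps = 4 if Lk <= 64 else 8
        return Lk, num_warps

    q = k = v = torch.empty(2, 4, 1024, 128, dtype=torch.float16)
    print(pick_launch_params(q, k, v))  # (128, 8)
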
@@ -233,13 +233,16 @@ This is a Triton implementation of the Flash Attention algorithm
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128
         # shape constraints
-        Lq, Lk = q.shape[-1], k.shape[-1]
-        assert Lq == Lk
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
         tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
         _fwd_kernel[grid](
             q, k, v, sm_scale,
             tmp, L, m,
@@ -250,14 +253,14 @@ This is a Triton implementation of the Flash Attention algorithm
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=64, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
         ctx.grid = grid
         ctx.sm_scale = sm_scale
-        ctx.BLOCK_DMODEL = 64
+        ctx.BLOCK_DMODEL = Lk
         return o
     @staticmethod
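
The ctx assignments above are what make the backward hunks below work: forward stores the tensors and the now-dynamic BLOCK_DMODEL on the autograd context, and backward reads them back to configure its own kernel launch. A stripped-down illustration of that torch.autograd.Function pattern (the toy op and its names are not part of the tutorial):

    import torch

    class _ScaledSquare(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            y = scale * x * x
            ctx.save_for_backward(x)   # tensors go through save_for_backward
            ctx.scale = scale          # plain values (like BLOCK_DMODEL) hang off ctx
            return y

        @staticmethod
        def backward(ctx, grad_y):
            (x,) = ctx.saved_tensors
            grad_x = grad_y * 2 * ctx.scale * x
            return grad_x, None        # one gradient per forward input; None for scale

    x = torch.randn(4, requires_grad=True)
    _ScaledSquare.apply(x, 0.5).sum().backward()
    print(torch.allclose(x.grad, x))   # d/dx(0.5 * x**2) = x, so this prints True
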
@@ -274,6 +277,8 @@ This is a Triton implementation of the Flash Attention algorithm
             do_scaled, delta,
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
+        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
         _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
@@ -286,7 +291,7 @@ This is a Triton implementation of the Flash Attention algorithm
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None
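
For context on the launch sizes in these hunks: the forward grid is two-dimensional (query blocks x batch*heads), while the backward kernel is launched with (ctx.grid[1],), i.e. one program per (batch, head) pair, and additionally receives ctx.grid[0], presumably the number of query blocks it has to iterate over. A small sketch with assumed example shapes:

    import math

    Z, H, N_CTX, BLOCK = 2, 4, 1024, 128          # assumed example shapes

    # Forward: one program per BLOCK query rows, for every (batch, head) pair.
    fwd_grid = (math.ceil(N_CTX / BLOCK), Z * H)  # (8, 8)

    # Backward: one program per (batch, head); the query-block count is passed in.
    bwd_grid = (fwd_grid[1],)                     # (8,)
    num_query_blocks = fwd_grid[0]                # 8

    print(fwd_grid, bwd_grid, num_query_blocks)
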
@@ -385,7 +390,7 @@ This is a Triton implementation of the Flash Attention algorithm
 .. rst-class:: sphx-glr-timing
-   **Total running time of the script:** ( 0 minutes 0.078 seconds)
+   **Total running time of the script:** ( 0 minutes 0.073 seconds)
 .. _sphx_glr_download_getting-started_tutorials_06-fused-attention.py: