trying more things

Phil Tillet
2022-12-27 20:58:31 -08:00
parent 4182e90862
commit 0d6e6cf578
6 changed files with 63 additions and 35 deletions

@@ -194,6 +194,7 @@ def _bwd_kernel(
+empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
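
The hunk above adds a module-level `empty` tensor. The commit does not show how it is used, but a common reason for a spare CUDA buffer like this in Triton code of this vintage (my assumption, not confirmed by the diff): kernels had no print facility, so intermediates were stored into a scratch tensor and read back on the host. A minimal sketch of that pattern, with a hypothetical kernel and values:

import torch
import triton
import triton.language as tl

@triton.jit
def _debug_kernel(out_ptr, scratch_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    val = offs.to(tl.float32) * 2.0
    # dump the intermediate into the scratch buffer for host-side inspection
    tl.store(scratch_ptr + offs, val)
    tl.store(out_ptr + offs, val + 1.0)

empty = torch.empty(128, device="cuda")  # scratch buffer, as in the diff
out = torch.empty(128, device="cuda")
_debug_kernel[(1,)](out, empty, BLOCK=128)
print(empty[:8])  # intermediate values, read back on the host
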
@@ -220,7 +221,7 @@ class _attention(torch.autograd.Function):
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
-num_stages=2,
+num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
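
This hunk drops `num_stages` from 2 to 1. In Triton, `num_stages` is a launch-time meta-parameter that controls how deeply the compiler software-pipelines the kernel's inner loop; lowering it reduces shared-memory pressure at the cost of less overlap between memory traffic and compute, and `num_stages=1` effectively disables pipelining. A minimal sketch of where the knob sits at launch, using a generic (hypothetical) vector-add kernel rather than the attention kernel from this file:

import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# num_warps and num_stages are passed alongside the kernel arguments,
# exactly as in the launch shown in the hunk above
_add_kernel[grid](x, y, out, x.numel(), BLOCK=1024, num_warps=4, num_stages=1)
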
@@ -335,7 +336,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
total_flops = 2*flops_per_matmul
-print(total_flops/ms*1e-9)
+# print(total_flops/ms*1e-9)
+print(ms)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
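
For reference, the expression the last hunk comments out converts the measured runtime to TFLOP/s: total_flops is in FLOPs, ms is in milliseconds, and FLOPs/ms * 1e-9 equals FLOPs/s * 1e-12. In flops_per_matmul, the leading 2 counts the multiply and add per MAC, the trailing 0.5 reflects the causal mask (only half of the N_CTX x N_CTX score matrix is computed), and the outer factor of 2 covers the two matmuls, Q @ K^T and P @ V. A small worked example with hypothetical shapes and timing:

BATCH, H, N_CTX, D_HEAD = 4, 48, 4096, 64   # hypothetical shapes
flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
ms = 10.0                                   # hypothetical runtime in milliseconds
# FLOPs/ms * 1e-9 == FLOPs/s * 1e-12, i.e. TFLOP/s
print(total_flops / ms * 1e-9)              # ~41.2 TFLOP/s for these numbers
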