[TUTORIALS] Added more credits in fused attention tutorial

Phil Tillet
2022-07-13 23:48:58 -07:00
parent 0a3f3d5f25
commit 5b04331dd2


@@ -1,7 +1,8 @@
 """
 Fused Attention
 ===============
-This is a Triton implementation of the Flash Attention algorithm (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf)
+This is a Triton implementation of the Flash Attention algorithm
+(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
 """
 
 import pytest
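
For context, the two cited papers share the same core idea: compute softmax(QK^T)V block by block with a running ("online") softmax, so the full attention matrix is never materialized. The sketch below is an illustrative PyTorch restatement of that recurrence, not the Triton kernel this tutorial implements; every name in it is invented for the example.

# Hedged sketch of blockwise attention with an online softmax, the idea
# credited above (Rabe & Staats 2021; Dao et al. 2022). Illustration only --
# this is not the tutorial's Triton kernel, just the recurrence behind it.
import torch

def blockwise_attention(q, k, v, block: int = 64):
    """q, k, v: (N, d). Returns softmax(q @ k.T / sqrt(d)) @ v."""
    n, d = q.shape
    scale = d ** -0.5
    # Per-query-row running statistics: max score m, normalizer l, output acc.
    m = torch.full((n, 1), float("-inf"))
    l = torch.zeros(n, 1)
    acc = torch.zeros(n, d)
    for start in range(0, n, block):
        kb = k[start:start + block]                  # (B, d) key block
        vb = v[start:start + block]                  # (B, d) value block
        s = (q @ kb.T) * scale                       # (N, B) scores
        m_new = torch.maximum(m, s.max(dim=1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                 # rescale old stats
        p = torch.exp(s - m_new)                     # (N, B) unnormalized probs
        l = l * alpha + p.sum(dim=1, keepdim=True)
        acc = acc * alpha + p @ vb
        m = m_new
    return acc / l

# Sanity check against the naive implementation.
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)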
@@ -349,5 +350,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
         return ms
 
-
-bench_flash_attention.run(save_path='.', print_data=True)
+# only works on A100 at the moment
+# bench_flash_attention.run(save_path='.', print_data=True)
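
An alternative to commenting the call out would be to gate the benchmark on the detected GPU. The sketch below assumes the tutorial's bench_flash_attention is in scope and that an A100 can be identified by compute capability (8, 0); that check is an illustration, not part of this commit.

import torch

# Sketch: run the benchmark only on hardware where it is known to work.
# Using compute capability (8, 0) as a stand-in for "A100" is an assumption
# made for this illustration.
if torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 0):
    bench_flash_attention.run(save_path='.', print_data=True)
else:
    print("skipping flash attention benchmark: only works on A100 at the moment")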