diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 89aadb1b4..c19ee498a 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -1,7 +1,8 @@ """ Fused Attention =============== -This is a Triton implementation of the Flash Attention algorithm (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf) +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) """ import pytest @@ -349,5 +350,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms - -bench_flash_attention.run(save_path='.', print_data=True) +# only works on A100 at the moment +# bench_flash_attention.run(save_path='.', print_data=True)