[TUTORIALS] Added more credits in fused attention tutorial
@@ -1,7 +1,8 @@
 """
 Fused Attention
 ===============
-This is a Triton implementation of the Flash Attention algorithm (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf)
+This is a Triton implementation of the Flash Attention algorithm
+(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
 """
 
 import pytest
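
The added credits point at the two papers behind the kernel: Flash Attention (Dao et al.) and memory-efficient attention (Rabe and Staats). For orientation, the computation the fused tutorial kernel implements is standard scaled dot-product attention; below is a minimal unfused PyTorch sketch of that reference computation, not part of this commit, with an illustrative function name:

    import math

    import torch


    def attention_reference(q, k, v):
        # Plain scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
        # The tutorial's Triton kernel fuses these steps so the full
        # (N_CTX x N_CTX) score matrix never materializes in GPU memory.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
        return torch.matmul(probs, v)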
@@ -349,5 +350,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
 
 # only works on A100 at the moment
-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)
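
The second hunk comments out the benchmark invocation because, per the adjacent comment, the kernel only works on A100 at the moment. One way a reader could re-enable it safely is to gate the call on compute capability; this guard is an assumption for illustration and not part of the commit (A100 reports capability (8, 0) via PyTorch's torch.cuda.get_device_capability()):

    import torch

    # Hypothetical guard, not in the commit: only run the benchmark on
    # Ampere-or-newer GPUs (sm_80+), which includes the A100 the kernel targets.
    if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
        bench_flash_attention.run(save_path='.', print_data=True)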