This commit is contained in:
Phil Tillet
2023-01-04 11:25:03 -08:00
parent e70e1e76b4
commit 36da342893

View File

@@ -331,7 +331,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 11)],
+    x_vals=[2**i for i in range(10, 15)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
-# bench_flash_attention.run(save_path='.', print_data=True)
+bench_flash_attention.run(save_path='.', print_data=True)