commit 66fa2f2975 (parent b162c44d59)
Author: Philippe Tillet
Date: 2023-01-09 23:11:51 -08:00
2 changed files with 6 additions and 6 deletions

File 1 of 2 (compiler pass pipeline):

@@ -905,7 +905,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_licm_pass()
     pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
-    # pm.add_tritongpu_optimize_load_convert_pass()
+    pm.add_tritongpu_optimize_load_convert_pass()
     pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
     pm.add_cse_pass()
     pm.add_symbol_dce_pass()
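
This hunk switches the optimize-load-convert pass on in the TTIR-to-TTGIR lowering pipeline; it had been left commented out. For context, a minimal sketch of how this stage is driven (only the signature comes from the hunk header above; the argument values are hypothetical):

    # compile a TTIR module down to TTGIR; `mod` is a Triton IR module and
    # compute_capability names the target GPU, e.g. 80 for A100 (hypothetical values)
    ttgir = ttir_to_ttgir(mod, num_warps=4, num_stages=2, compute_capability=80)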

File 2 of 2 (fused-attention tutorial benchmark):

@@ -44,13 +44,13 @@ def _fwd_kernel(
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     q = tl.load(q_ptrs)
-    q *= (q.to(tl.float32) * sm_scale).to(tl.float16)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
         # -- compute qk ----
         k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        qk *= sm_scale
         qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
         # compute new m
         m = tl.maximum(tl.max(qk, 1), m_prev)
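
Reading the hunk: the experimental pre-scaling of q is dropped and the softmax scaling moves back onto the qk tile inside the loop. The two placements are mathematically interchangeable because tl.dot is linear in q, but the removed line used q *= rather than q =, which multiplies q by itself. A standalone NumPy sketch (illustration only, not part of the commit) of both facts:

    import numpy as np

    rng = np.random.default_rng(0)
    q = rng.standard_normal((16, 64)).astype(np.float32)
    k = rng.standard_normal((64, 32)).astype(np.float32)
    sm_scale = 0.125

    # scaling qk after the dot (what this commit restores) ...
    qk_after = (q @ k) * sm_scale
    # ... equals scaling q before the dot, because matmul is linear
    qk_before = (q * sm_scale) @ k
    assert np.allclose(qk_after, qk_before, atol=1e-5)

    # the removed line effectively computed q * (q * sm_scale), i.e. q**2 * sm_scale,
    # which is NOT equivalent
    q_bug = q * (q * sm_scale)
    assert not np.allclose(q_bug @ k, qk_after, atol=1e-5)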
@@ -337,7 +337,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 13)],
+    x_vals=[2**i for i in range(10, 17)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
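
Restoring range(10, 17) widens the sequence-length sweep from three points back to seven; in a Python REPL:

    >>> [2**i for i in range(10, 13)]
    [1024, 2048, 4096]
    >>> [2**i for i in range(10, 17)]
    [1024, 2048, 4096, 8192, 16384, 32768, 65536]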
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
     ylabel='ms',
     plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
-) for mode in ['bwd']]
+) for mode in ['fwd', 'bwd']]
 @triton.testing.perf_report(configs)
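
Reinstating 'fwd' makes the comprehension emit one Benchmark config, and therefore one plot, per direction again. With the constants defined above, the generated plot names are:

    >>> BATCH, N_HEADS, D_HEAD = 4, 48, 64
    >>> [f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}' for mode in ['fwd', 'bwd']]
    ['fused-attention-batch4-head48-d64-fwd', 'fused-attention-batch4-head48-d64-bwd']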
@@ -367,7 +367,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
         total_flops = 2 * flops_per_matmul
         # print(total_flops/ms*1e-9)
-        print(ms)
+        # print(ms)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
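
Silencing print(ms) leaves only the returned runtime. For reference, the throughput figure the other commented-out print would report: flops_per_matmul counts one N_CTX x N_CTX x D_HEAD matmul over batch and heads (the 0.5 factor accounts for the causal mask), and the forward pass performs two such matmuls. A worked check with the benchmark's defaults (the 10 ms runtime is hypothetical):

    BATCH, H, N_CTX, D_HEAD = 4, 48, 4096, 64
    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5  # ~2.06e11
    total_flops = 2 * flops_per_matmul                                # ~4.12e11
    ms = 10.0                                                         # hypothetical timing
    print(total_flops / ms * 1e-9)                                    # ~41.2 TFLOP/s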
@@ -383,4 +383,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         return ms
-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)
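
With the driver line commented out again, running the tutorial no longer launches the sweep automatically; to reproduce the benchmark, call the decorated function directly:

    # prints one table per mode and writes the fused-attention-* plots into the CWD
    bench_flash_attention.run(save_path='.', print_data=True)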