@@ -905,7 +905,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_licm_pass()
     pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
-    # pm.add_tritongpu_optimize_load_convert_pass()
+    pm.add_tritongpu_optimize_load_convert_pass()
     pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
     pm.add_cse_pass()
     pm.add_symbol_dce_pass()
@@ -44,13 +44,13 @@ def _fwd_kernel(
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     q = tl.load(q_ptrs)
-    q *= (q.to(tl.float32) * sm_scale).to(tl.float16)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
         # -- compute qk ----
         k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        qk *= sm_scale
         qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
         # compute new m
         m = tl.maximum(tl.max(qk, 1), m_prev)
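Not part of the diff, just context on the kernel change above: the removed line appears intended to pre-scale `q` (with a round-trip through fp16), while the added line applies `sm_scale` to the fp32 `qk` accumulator after the dot. A minimal NumPy sketch, with made-up tile shapes, showing that post-scaling matches pre-scaling up to the extra fp16 rounding:

# Illustrative sketch only (NumPy stand-in for the Triton tiles; shapes are made up).
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((128, 64)).astype(np.float16)   # BLOCK_M x BLOCK_DMODEL
k = rng.standard_normal((64, 128)).astype(np.float16)   # BLOCK_DMODEL x BLOCK_N
sm_scale = 1.0 / np.sqrt(64)

# Pre-scaling: scale q, round back to fp16, then take the dot product in fp32.
qk_pre = (q.astype(np.float32) * sm_scale).astype(np.float16).astype(np.float32) @ k.astype(np.float32)

# Post-scaling, as in the new kernel code: accumulate the dot in fp32, then scale.
qk_post = (q.astype(np.float32) @ k.astype(np.float32)) * sm_scale

# The results agree up to the fp16 rounding incurred by pre-scaling.
print(np.max(np.abs(qk_pre - qk_post)))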
@@ -337,7 +337,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 13)],
+    x_vals=[2**i for i in range(10, 17)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
     ylabel='ms',
     plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
-) for mode in ['bwd']]
+) for mode in ['fwd', 'bwd']]
 
 
 @triton.testing.perf_report(configs)
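The two hunks above only adjust the sweep range and the benchmarked modes. For readers unfamiliar with the harness, here is a minimal, hypothetical sketch of the same `triton.testing.Benchmark` / `perf_report` / `.run(...)` pattern on a toy workload (the matmul, the crude timing loop, and the `styles` argument are assumptions for illustration, not taken from this file):

# Hypothetical, self-contained usage sketch of the benchmarking harness used above.
import time

import torch
import triton
import triton.testing

toy_configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],                        # swept along the x-axis
    x_vals=[2**i for i in range(10, 13)],
    line_arg='provider',
    line_vals=['torch'],
    line_names=['Torch'],
    styles=[('blue', '-')],
    ylabel='ms',
    plot_name='toy-matmul',
    args={'D_HEAD': 64, 'dtype': torch.float16},
)]


@triton.testing.perf_report(toy_configs)
def bench_toy(N_CTX, D_HEAD, provider, dtype=torch.float16, device='cuda'):
    a = torch.randn((N_CTX, D_HEAD), dtype=dtype, device=device)
    b = torch.randn((D_HEAD, N_CTX), dtype=dtype, device=device)
    # Crude timing loop; the real benchmark uses Triton's own timing helpers.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        a @ b
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / 10 * 1e3  # milliseconds per call


bench_toy.run(save_path='.', print_data=True)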
@@ -367,7 +367,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
         total_flops = 2 * flops_per_matmul
         # print(total_flops/ms*1e-9)
-        print(ms)
+        # print(ms)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
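For reference, the FLOP estimate in this hunk is plain arithmetic over the sizes fixed earlier in the script (BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64). A small sketch; the `ms` value is hypothetical, and the `1e-9` factor mirrors the commented-out `total_flops/ms*1e-9` line (flops per millisecond scaled to TFLOP/s):

# Back-of-the-envelope check of the FLOP estimate (sizes from the benchmark setup above).
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64

# One batched matmul is 2*N_CTX*N_CTX*D_HEAD flops per head per batch;
# the 0.5 reflects the causal masking (only the lower-triangular half is computed).
flops_per_matmul = 2. * BATCH * N_HEADS * N_CTX * N_CTX * D_HEAD * 0.5   # ~2.06e11
total_flops = 2 * flops_per_matmul                                       # two matmuls, ~4.12e11

ms = 10.0  # hypothetical measured runtime in milliseconds
print(f"{total_flops:.3e} flops, {total_flops / ms * 1e-9:.1f} TFLOP/s")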
@@ -383,4 +383,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         return ms
 
 
-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)