diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index b40fe65fb..bf102b612 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -905,7 +905,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     pm.add_licm_pass()
     pm.add_tritongpu_combine_pass(compute_capability)
     pm.add_cse_pass()
-    # pm.add_tritongpu_optimize_load_convert_pass()
+    pm.add_tritongpu_optimize_load_convert_pass()
     pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
     pm.add_cse_pass()
     pm.add_symbol_dce_pass()
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 9da192a8b..e7455903e 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -44,13 +44,13 @@ def _fwd_kernel(
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     q = tl.load(q_ptrs)
-    q *= (q.to(tl.float32) * sm_scale).to(tl.float16)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
         # -- compute qk ----
         k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        qk *= sm_scale
         qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
         # compute new m
         m = tl.maximum(tl.max(qk, 1), m_prev)
@@ -337,7 +337,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 13)],
+    x_vals=[2**i for i in range(10, 17)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
     ylabel='ms',
     plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
-) for mode in ['bwd']]
+) for mode in ['fwd', 'bwd']]


 @triton.testing.perf_report(configs)
@@ -367,7 +367,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
         total_flops = 2 * flops_per_matmul
         # print(total_flops/ms*1e-9)
-        print(ms)
+        # print(ms)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
@@ -383,4 +383,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         return ms


-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)
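
Note (not part of the patch): the tutorial change moves the sm_scale multiplication off q (before the loop) onto the fp32 qk accumulator inside the loop. A minimal sketch of why the two orderings agree mathematically, using illustrative PyTorch tensors and placeholder block sizes rather than the Triton kernel's actual tiles:

import torch

# Illustrative shapes only; these are not the kernel's real BLOCK_* values.
torch.manual_seed(0)
BLOCK_M, BLOCK_N, BLOCK_DMODEL = 64, 64, 64
sm_scale = 1.0 / (BLOCK_DMODEL ** 0.5)

q = torch.randn(BLOCK_M, BLOCK_DMODEL, dtype=torch.float32)
k = torch.randn(BLOCK_DMODEL, BLOCK_N, dtype=torch.float32)

# New ordering: compute qk first, then scale the fp32 accumulator.
qk_scale_after = (q @ k) * sm_scale
# Alternative ordering: scale q before the dot product.
qk_scale_before = (q * sm_scale) @ k

assert torch.allclose(qk_scale_after, qk_scale_before, atol=1e-5)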