diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 72675b577..9da192a8b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -44,13 +44,13 @@ def _fwd_kernel( acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # load q: it will stay in SRAM throughout q = tl.load(q_ptrs) + q *= (q.to(tl.float32) * sm_scale).to(tl.float16) # loop over k, v and update accumulator for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): # -- compute qk ---- k = tl.load(k_ptrs) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) - qk *= sm_scale qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # compute new m m = tl.maximum(tl.max(qk, 1), m_prev) @@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark( ylabel='ms', plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} -) for mode in ['fwd']] +) for mode in ['bwd']] @triton.testing.perf_report(configs)