From 4182e90862e5514da42969f75284a60097730ae7 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 24 Dec 2022 00:31:05 -0800
Subject: [PATCH] less math

---
 python/triton/compiler.py              |  1 +
 python/tutorials/06-fused-attention.py | 60 +++++++++++++-------------
 2 files changed, 31 insertions(+), 30 deletions(-)

diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 96ebfc3cf..99e4e5928 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1562,6 +1562,7 @@ class CompiledKernel:
         if self.shared > max_shared:
             raise OutOfResources(self.shared, max_shared, "shared memory")
         mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
+        print(n_regs, n_spills)
         self.cu_module = mod
         self.cu_function = func
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index a185724f7..1bd787aaa 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -15,7 +15,7 @@ import triton.language as tl
 @triton.jit
 def _fwd_kernel(
     Q, K, V, sm_scale,
-    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
+    L, M,
     Out,
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
@@ -39,51 +39,49 @@ def _fwd_kernel(
     k_ptrs = K + off_k
     v_ptrs = V + off_v
     # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     q = tl.load(q_ptrs)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
-        # start_n = tl.multiple_of(start_n, BLOCK_N)
+        start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = tl.load(k_ptrs + start_n * stride_kn)
+        k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
-        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        p = tl.exp(qk - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
-        alpha = tl.exp(m_i - m_i_new)
-        beta = tl.exp(m_ij - m_i_new)
-        l_i_new = alpha * l_i + beta * l_ij
-        # -- update output accumulator --
-        # scale p
-        p_scale = beta / l_i_new
-        p = p * p_scale[:, None]
-        # scale acc
-        acc_scale = l_i / l_i_new * alpha
-        acc = acc * acc_scale[:, None]
+        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+        # compute new m
+        m = tl.maximum(tl.max(qk, 1), m_prev)
+        # correct old l
+        l_prev *= tl.exp(m_prev - m)
+        # attention weights
+        p = tl.exp(qk - m[:, None])
+        l = tl.sum(p, 1) + l_prev
+        l_rcp = 1. / l
+        # rescale operands of matmuls
+        p *= l_rcp
+        acc *= (l_prev * l_rcp)[:, None]
         # update acc
-        v = tl.load(v_ptrs + start_n * stride_vk)
         p = p.to(tl.float16)
+        v = tl.load(v_ptrs)
         acc += tl.dot(p, v)
         # update m_i and l_i
-        l_i = l_i_new
-        m_i = m_i_new
+        l_prev = l
+        m_prev = m
+        # update pointers
+        k_ptrs += BLOCK_N * stride_kn
+        v_ptrs += BLOCK_N * stride_vk
     # rematerialize offsets to save registers
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     # write back l and m
     l_ptrs = L + off_hz * N_CTX + offs_m
     m_ptrs = M + off_hz * N_CTX + offs_m
-    tl.store(l_ptrs, l_i)
-    tl.store(m_ptrs, m_i)
+    tl.store(l_ptrs, l_prev)
+    tl.store(m_ptrs, m_prev)
     # initialize pointers to output
     offs_n = tl.arange(0, BLOCK_DMODEL)
     off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
@@ -207,14 +205,13 @@ class _attention(torch.autograd.Function):
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
-        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
 
         _fwd_kernel[grid](
             q, k, v, sm_scale,
-            tmp, L, m,
+            L, m,
             o,
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -222,7 +219,7 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=2,
         )
 
@@ -336,6 +333,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
+        total_flops = 2*flops_per_matmul
+        print(total_flops/ms*1e-9)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
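
For reference, the recurrence that the rewritten inner loop implements can be checked outside Triton. The NumPy sketch below is not part of the commit; the function name, block size, and test shapes are illustrative, and sm_scale plus the causal mask are omitted. It keeps a running row-wise maximum m_prev, a running denominator l_prev, and an already-normalized accumulator acc, and folds in one K/V block per iteration in the same order as the new kernel code.

# Minimal NumPy sketch of the streaming-softmax update used by the new
# _fwd_kernel loop (illustrative only: no sm_scale, no causal mask).
import numpy as np

def streaming_attention(q, k, v, block_n=16):
    """Computes softmax(q @ k.T) @ v one key/value block at a time."""
    n_ctx, d = q.shape
    m_prev = np.full(n_ctx, -np.inf)   # running row-wise max of the scores
    l_prev = np.zeros(n_ctx)           # running softmax denominator
    acc = np.zeros((n_ctx, d))         # output, kept normalized by l_prev
    for start_n in range(0, n_ctx, block_n):
        kb = k[start_n:start_n + block_n]
        vb = v[start_n:start_n + block_n]
        qk = q @ kb.T
        # compute new m
        m = np.maximum(qk.max(axis=1), m_prev)
        # correct old l
        l_prev = l_prev * np.exp(m_prev - m)
        # attention weights
        p = np.exp(qk - m[:, None])
        l = p.sum(axis=1) + l_prev
        l_rcp = 1. / l
        # rescale operands of matmuls
        p = p * l_rcp[:, None]
        acc = acc * (l_prev * l_rcp)[:, None]
        # update acc
        acc += p @ vb
        # carry the running statistics to the next block
        l_prev = l
        m_prev = m
    return acc

# Agrees with the unblocked reference softmax(q @ k.T) @ v:
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
s = np.exp(q @ k.T - (q @ k.T).max(axis=1, keepdims=True))
assert np.allclose(streaming_attention(q, k, v), (s / s.sum(axis=1, keepdims=True)) @ v)

Compared with the previous alpha/beta/p_scale/acc_scale bookkeeping, this form needs only one exp correction and one reciprocal per block, which is what the commit subject "less math" refers to.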