From 71e3143eaf815644bf0258ec9b3d4858c8fa1b7f Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Thu, 29 Dec 2022 14:40:27 -0800
Subject: [PATCH] .

---
 lib/Dialect/TritonGPU/IR/Dialect.cpp   |  6 ++--
 python/tutorials/06-fused-attention.py | 49 +++++++++++++-------------
 2 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index c13fcde86..0a50db460 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -104,8 +104,10 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
-    return {1};
-    return getSizePerThread(sliceLayout.getParent());
+    auto ret = getSizePerThread(sliceLayout.getParent());
+    return ret;
+    // ret.erase(ret.begin() + sliceLayout.getDim());
+    return ret;
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.isAmpere()) {
       return {2, 2};
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index bf3405335..333a503a5 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -162,26 +162,26 @@ def _bwd_kernel(
             q = tl.load(q_ptrs)
             # recompute p = softmax(qk, dim=-1).T
             # NOTE: `do` is pre-divided by `l`; no normalization here
-            qk = tl.dot(q, k, trans_b=True)
+            qk = tl.dot(q, tl.trans(k))
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
             m = tl.load(m_ptrs + offs_m_curr)
             p = tl.exp(qk * sm_scale - m[:, None])
             # compute dv
             do = tl.load(do_ptrs)
-            dv += tl.dot(p.to(tl.float16), do, trans_a=True)
+            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
             # compute dp = dot(v, do)
             Di = tl.load(D_ptrs + offs_m_curr)
             dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp += tl.dot(do, v, trans_b=True)
+            dp += tl.dot(do, tl.trans(v))
             # compute ds = p * (dp - delta[:, None])
             ds = p * dp * sm_scale
             # compute dk = dot(ds.T, q)
-            dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
-            # # compute dq
+            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
+            # compute dq
             dq = tl.load(dq_ptrs)
             dq += tl.dot(ds.to(tl.float16), k)
             tl.store(dq_ptrs, dq)
-            # # increment pointers
+            # increment pointers
             dq_ptrs += BLOCK_M * stride_qm
             q_ptrs += BLOCK_M * stride_qm
             do_ptrs += BLOCK_M * stride_qm
@@ -191,7 +191,7 @@ def _bwd_kernel(
         tl.store(dv_ptrs, dv)
         tl.store(dk_ptrs, dk)
 
-# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
+# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
 
 empty = torch.empty(128, device="cuda")
@@ -210,7 +210,7 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
 
-        _fwd_kernel[grid](
+        h = _fwd_kernel[grid](
             q, k, v, sm_scale,
             L, m,
             o,
@@ -255,8 +255,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
-        _bwd_kernel[(ctx.grid[1],)](
+        pgm = _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
             dq, dk, dv,
@@ -268,7 +267,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
             num_stages=1,
         )
         return dq, dk, dv, None
@@ -281,8 +280,8 @@ attention = _attention.apply
 def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.1).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.1).requires_grad_()
     sm_scale = 0.2
     dout = torch.randn_like(q)
     # reference implementation
@@ -294,23 +293,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    # ref_out.backward(dout)
-    # ref_dv, v.grad = v.grad.clone(), None
-    # ref_dk, k.grad = k.grad.clone(), None
-    # ref_dq, q.grad = q.grad.clone(), None
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = attention(q, k, v, sm_scale)
     # print(ref_out)
     # print(tri_out)
-    # tri_out.backward(dout)
-    # tri_dv, v.grad = v.grad.clone(), None
-    # tri_dk, k.grad = k.grad.clone(), None
-    # tri_dq, q.grad = q.grad.clone(), None
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
-    # triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dv, tri_dv)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)
 
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
@@ -361,4 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
         return ms
 
-bench_flash_attention.run(save_path='.', print_data=True)
\ No newline at end of file
+# bench_flash_attention.run(save_path='.', print_data=True)
\ No newline at end of file