commit 71e3143eaf (parent 263ad883a6)
Author: Phil Tillet
Date:   2022-12-29 14:40:27 -08:00

2 changed files with 28 additions and 27 deletions

@@ -104,8 +104,10 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
-    return {1};
-    return getSizePerThread(sliceLayout.getParent());
+    auto ret = getSizePerThread(sliceLayout.getParent());
+    return ret;
+    // ret.erase(ret.begin() + sliceLayout.getDim());
+    return ret;
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.isAmpere()) {
       return {2, 2};
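Note on the hunk above: the SliceEncodingAttr branch of getSizePerThread now forwards the parent layout's per-thread sizes instead of hard-coding {1}; the commented-out erase() hints that the sliced dimension is meant to be dropped eventually, while the early `return ret;` keeps the full parent shape for now. A minimal Python sketch of that intended behavior (illustrative only; the function and argument names are not part of the C++ API):

# Illustrative sketch of the slice-layout logic above; names are hypothetical.
def slice_size_per_thread(parent_size_per_thread, slice_dim, drop_sliced_dim=False):
    ret = list(parent_size_per_thread)   # e.g. [2, 2] for an Ampere MMA parent
    if drop_sliced_dim:                  # corresponds to the commented-out erase()
        del ret[slice_dim]               # [2, 2] sliced along dim 0 -> [2]
    return ret

The second changed file is the flash-attention script, whose backward kernel and test are updated below.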

@@ -162,26 +162,26 @@ def _bwd_kernel(
     q = tl.load(q_ptrs)
     # recompute p = softmax(qk, dim=-1).T
     # NOTE: `do` is pre-divided by `l`; no normalization here
-    qk = tl.dot(q, k, trans_b=True)
+    qk = tl.dot(q, tl.trans(k))
     qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
     m = tl.load(m_ptrs + offs_m_curr)
     p = tl.exp(qk * sm_scale - m[:, None])
     # compute dv
     do = tl.load(do_ptrs)
-    dv += tl.dot(p.to(tl.float16), do, trans_a=True)
+    dv += tl.dot(tl.trans(p.to(tl.float16)), do)
     # compute dp = dot(v, do)
     Di = tl.load(D_ptrs + offs_m_curr)
     dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-    dp += tl.dot(do, v, trans_b=True)
+    dp += tl.dot(do, tl.trans(v))
     # compute ds = p * (dp - delta[:, None])
     ds = p * dp * sm_scale
     # compute dk = dot(ds.T, q)
-    dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
-    # # compute dq
+    dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
+    # compute dq
     dq = tl.load(dq_ptrs)
     dq += tl.dot(ds.to(tl.float16), k)
     tl.store(dq_ptrs, dq)
-    # # increment pointers
+    # increment pointers
     dq_ptrs += BLOCK_M * stride_qm
     q_ptrs += BLOCK_M * stride_qm
     do_ptrs += BLOCK_M * stride_qm
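The recurring edit in this hunk is the dot API: the trans_a=/trans_b= keyword flags on tl.dot are replaced by transposing the operand explicitly with tl.trans. A minimal, self-contained sketch of the new pattern (kernel and variable names are illustrative, not taken from the diff):

import torch
import triton
import triton.language as tl

@triton.jit
def _dot_bt(A, B, C, BLOCK: tl.constexpr):
    # load two BLOCK x BLOCK tiles and compute a @ b.T
    offs = tl.arange(0, BLOCK)
    a = tl.load(A + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(B + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, tl.trans(b))  # previously: tl.dot(a, b, trans_b=True)
    tl.store(C + offs[:, None] * BLOCK + offs[None, :], c)

a = torch.randn((16, 16), device="cuda", dtype=torch.float16)
b = torch.randn((16, 16), device="cuda", dtype=torch.float16)
c = torch.empty((16, 16), device="cuda", dtype=torch.float32)
_dot_bt[(1,)](a, b, c, BLOCK=16)
# c should be close to (a @ b.T).float()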
@@ -191,7 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
-# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
+# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 empty = torch.empty(128, device="cuda")
@@ -210,7 +210,7 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
-        _fwd_kernel[grid](
+        h = _fwd_kernel[grid](
             q, k, v, sm_scale,
             L, m,
             o,
@@ -255,8 +255,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
-        _bwd_kernel[(ctx.grid[1],)](
+        pgm = _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
             dq, dk, dv,
@@ -268,7 +267,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
             num_stages=1,
         )
         return dq, dk, dv, None
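Two things change in the backward launch above: num_warps is pinned to 8 instead of being derived from BLOCK_DMODEL, and the launch result is captured in pgm (the forward launch is likewise captured in h). Capturing the handle is typically done to inspect compiled artifacts while debugging; a hedged sketch, assuming the handle exposes an asm dict the way Triton's compiled kernels do:

# Hypothetical debugging use of the captured handle `pgm`; assumes the launch returns
# a compiled-kernel object with an `asm` dict of generated artifacts.
print(sorted(pgm.asm.keys()))   # which artifacts were produced (IRs, PTX, ...)
print(pgm.asm["ptx"])           # dump the generated PTX for the backward kernel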
@@ -281,8 +280,8 @@ attention = _attention.apply
 def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.1).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.1).requires_grad_()
     sm_scale = 0.2
     dout = torch.randn_like(q)
     # reference implementation
@@ -294,23 +293,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    # ref_out.backward(dout)
-    # ref_dv, v.grad = v.grad.clone(), None
-    # ref_dk, k.grad = k.grad.clone(), None
-    # ref_dq, q.grad = q.grad.clone(), None
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = attention(q, k, v, sm_scale)
     # print(ref_out)
     # print(tri_out)
-    # tri_out.backward(dout)
-    # tri_dv, v.grad = v.grad.clone(), None
-    # tri_dk, k.grad = k.grad.clone(), None
-    # tri_dq, q.grad = q.grad.clone(), None
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
-    # triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dv, tri_dv)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
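With the backward path re-enabled, test_op now checks dq, dk, and dv against the eager PyTorch reference in addition to the forward output. An illustrative direct invocation (the shape values are examples, not taken from the diff; requires a CUDA device):

# Hypothetical shapes; any (Z, H, N_CTX, D_HEAD) the kernel supports would do.
test_op(Z=4, H=48, N_CTX=1024, D_HEAD=64, dtype=torch.float16)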
@@ -361,4 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)
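The benchmark driver itself is untouched; this commit only comments out its invocation. To reproduce the sweep locally, one would restore the final call:

bench_flash_attention.run(save_path='.', print_data=True)  # re-enable to run the benchmark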