@@ -104,8 +104,10 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
     return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
                                  blockedLayout.getSizePerThread().end());
   } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
-    return {1};
-    return getSizePerThread(sliceLayout.getParent());
+    auto ret = getSizePerThread(sliceLayout.getParent());
+    return ret;
+    // ret.erase(ret.begin() + sliceLayout.getDim());
+    return ret;
   } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     if (mmaLayout.isAmpere()) {
       return {2, 2};
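For orientation, here is a hedged Python model (not part of the C++ file) of the lookup this hunk edits: a blocked layout reports its own sizePerThread, a slice layout defers to its parent (with the dimension erase left commented out above), and an Ampere MMA layout is fixed at 2x2 per thread. The dict representation and function name are illustrative stand-ins for the real MLIR attributes.

def get_size_per_thread(layout):
    # blocked layout: per-thread tile sizes are stored directly
    if layout["kind"] == "blocked":
        return list(layout["sizePerThread"])
    # slice layout: inherit from the parent layout
    if layout["kind"] == "slice":
        ret = get_size_per_thread(layout["parent"])
        # ret.pop(layout["dim"])  # the erase kept commented out in the hunk
        return ret
    # Ampere mma layout: each thread owns a 2x2 fragment
    if layout["kind"] == "mma" and layout.get("isAmpere"):
        return [2, 2]
    raise NotImplementedError(layout["kind"])

blocked = {"kind": "blocked", "sizePerThread": [1, 4]}
sliced = {"kind": "slice", "parent": blocked, "dim": 0}
assert get_size_per_thread(sliced) == [1, 4]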
@@ -162,26 +162,26 @@ def _bwd_kernel(
             q = tl.load(q_ptrs)
             # recompute p = softmax(qk, dim=-1).T
             # NOTE: `do` is pre-divided by `l`; no normalization here
-            qk = tl.dot(q, k, trans_b=True)
+            qk = tl.dot(q, tl.trans(k))
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
             m = tl.load(m_ptrs + offs_m_curr)
             p = tl.exp(qk * sm_scale - m[:, None])
             # compute dv
             do = tl.load(do_ptrs)
-            dv += tl.dot(p.to(tl.float16), do, trans_a=True)
+            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
             # compute dp = dot(v, do)
             Di = tl.load(D_ptrs + offs_m_curr)
             dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp += tl.dot(do, v, trans_b=True)
+            dp += tl.dot(do, tl.trans(v))
             # compute ds = p * (dp - delta[:, None])
             ds = p * dp * sm_scale
             # compute dk = dot(ds.T, q)
-            dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
-            # # compute dq
+            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
+            # compute dq
             dq = tl.load(dq_ptrs)
             dq += tl.dot(ds.to(tl.float16), k)
             tl.store(dq_ptrs, dq)
-            # # increment pointers
+            # increment pointers
             dq_ptrs += BLOCK_M * stride_qm
             q_ptrs += BLOCK_M * stride_qm
             do_ptrs += BLOCK_M * stride_qm
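The substantive change in this hunk is swapping the trans_a=/trans_b= keyword arguments of tl.dot for an explicit tl.trans on the operand. A minimal, self-contained sketch of that pattern follows; the kernel name, tile shape, and standalone test harness are illustrative and not part of the original file.

import torch
import triton
import triton.language as tl

@triton.jit
def _ab_t_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # load two row-major BLOCK x BLOCK tiles
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # c = a @ b.T, written with an explicit transpose instead of trans_b=True
    c = tl.dot(a, tl.trans(b))
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)

a = torch.randn(16, 16, device="cuda", dtype=torch.float16)
b = torch.randn(16, 16, device="cuda", dtype=torch.float16)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
_ab_t_kernel[(1,)](a, b, c, BLOCK=16)
torch.testing.assert_close(c, (a @ b.t()).float(), atol=1e-2, rtol=1e-2)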
@@ -191,7 +191,7 @@ def _bwd_kernel(
         tl.store(dv_ptrs, dv)
         tl.store(dk_ptrs, dk)

-# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
+# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)

 empty = torch.empty(128, device="cuda")

@@ -210,7 +210,7 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8

-        _fwd_kernel[grid](
+        h = _fwd_kernel[grid](
             q, k, v, sm_scale,
             L, m,
             o,
@@ -255,8 +255,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )

-        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
-        _bwd_kernel[(ctx.grid[1],)](
+        pgm = _bwd_kernel[(ctx.grid[1],)](
             q, k, v, ctx.sm_scale,
             o, do_scaled,
             dq, dk, dv,
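The forward and backward launches above now bind the launch result (h, pgm). A hedged sketch of why that is useful while debugging: in Triton of this era a kernel launch returns a compiled-kernel handle whose intermediate representations can be dumped. The toy kernel below and the asm attribute and key names are assumptions about that interface, not code from this file.

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
pgm = _copy_kernel[(4,)](x, y, x.numel(), BLOCK=256)  # launch returns a handle
print(list(pgm.asm.keys()))  # typically ttir / ttgir / llir / ptx / cubin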
@@ -268,7 +267,7 @@ class _attention(torch.autograd.Function):
             q.shape[0], q.shape[1], q.shape[2],
             ctx.grid[0],
             BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
             num_stages=1,
         )
         return dq, dk, dv, None
@@ -281,8 +280,8 @@ attention = _attention.apply
 def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.1).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.1).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.1).requires_grad_()
     sm_scale = 0.2
     dout = torch.randn_like(q)
     # reference implementation
@@ -294,23 +293,23 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     p = torch.softmax(p.float(), dim=-1).half()
     # p = torch.exp(p)
     ref_out = torch.matmul(p, v)
-    # ref_out.backward(dout)
-    # ref_dv, v.grad = v.grad.clone(), None
-    # ref_dk, k.grad = k.grad.clone(), None
-    # ref_dq, q.grad = q.grad.clone(), None
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = attention(q, k, v, sm_scale)
     # print(ref_out)
     # print(tri_out)
-    # tri_out.backward(dout)
-    # tri_dv, v.grad = v.grad.clone(), None
-    # tri_dk, k.grad = k.grad.clone(), None
-    # tri_dq, q.grad = q.grad.clone(), None
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
     # compare
     triton.testing.assert_almost_equal(ref_out, tri_out)
-    # triton.testing.assert_almost_equal(ref_dv, tri_dv)
-    # triton.testing.assert_almost_equal(ref_dk, tri_dk)
-    # triton.testing.assert_almost_equal(ref_dq, tri_dq)
+    triton.testing.assert_almost_equal(ref_dv, tri_dv)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)

 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
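With the backward checks re-enabled, the test compares the kernel's gradients against plain PyTorch autograd run through a causal softmax-attention reference (the same masking that the offs_m_curr >= offs_n check implements inside the kernel). Below is a compact sketch of such a reference, under the assumption that the tutorial's version looks roughly like this; names and shapes are illustrative.

import torch

def reference_attention(q, k, v, sm_scale):
    # q, k, v: (Z, H, N_CTX, D_HEAD); causal mask keeps only positions j <= i
    n_ctx = q.shape[-2]
    causal = torch.tril(torch.ones((n_ctx, n_ctx), device=q.device, dtype=torch.bool))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    p = p.masked_fill(~causal, float("-inf"))
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)

q = torch.randn(1, 2, 64, 32, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
reference_attention(q, k, v, 0.2).backward(torch.randn_like(q))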
@@ -361,4 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
         return ms

-bench_flash_attention.run(save_path='.', print_data=True)
+# bench_flash_attention.run(save_path='.', print_data=True)
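The benchmark entry point is commented out here, presumably while the backward kernel is being debugged. For reference, a hedged sketch of the do_bench call the benchmark relies on, applied to a toy function; the keyword arguments mirror the call in the hunk above and may differ in other Triton versions.

import torch
import triton

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
fn = lambda: torch.matmul(a, b)
# warmup/rep are in milliseconds; percentiles=None returns the mean time
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=100)
print(f"{ms:.3f} ms per matmul")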