less math
@@ -1562,6 +1562,7 @@ class CompiledKernel:
         if self.shared > max_shared:
             raise OutOfResources(self.shared, max_shared, "shared memory")
         mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
+        print(n_regs, n_spills)
         self.cu_module = mod
         self.cu_function = func

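For context on why the register and spill counts are worth printing while iterating on the kernel: registers per thread bound how many warps an SM can keep resident, and any non-zero spill count means the compiler ran out of registers and is spilling to local memory. A back-of-the-envelope sketch, not part of the commit; the 65536-registers-per-SM and 64-warps-per-SM figures assume an A100-class GPU:

# Rough occupancy bound implied by the printed n_regs (illustrative sketch only).
def warps_limited_by_registers(n_regs, regs_per_sm=65536, max_warps_per_sm=64):
    # each warp needs n_regs registers for each of its 32 threads
    if n_regs == 0:
        return max_warps_per_sm
    return min(max_warps_per_sm, regs_per_sm // (n_regs * 32))

print(warps_limited_by_registers(128))  # 128 regs/thread -> at most 16 resident warps per SM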
@@ -15,7 +15,7 @@ import triton.language as tl
 @triton.jit
 def _fwd_kernel(
     Q, K, V, sm_scale,
-    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
+    L, M,
     Out,
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
@@ -39,51 +39,49 @@ def _fwd_kernel(
     k_ptrs = K + off_k
     v_ptrs = V + off_v
     # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     q = tl.load(q_ptrs)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
-        # start_n = tl.multiple_of(start_n, BLOCK_N)
+        start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = tl.load(k_ptrs + start_n * stride_kn)
+        k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
-        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        p = tl.exp(qk - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
-        alpha = tl.exp(m_i - m_i_new)
-        beta = tl.exp(m_ij - m_i_new)
-        l_i_new = alpha * l_i + beta * l_ij
-        # -- update output accumulator --
-        # scale p
-        p_scale = beta / l_i_new
-        p = p * p_scale[:, None]
-        # scale acc
-        acc_scale = l_i / l_i_new * alpha
-        acc = acc * acc_scale[:, None]
+        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+        # compute new m
+        m = tl.maximum(tl.max(qk, 1), m_prev)
+        # correct old l
+        l_prev *= tl.exp(m_prev - m)
+        # attention weights
+        p = tl.exp(qk - m[:, None])
+        l = tl.sum(p, 1) + l_prev
+        l_rcp = 1. / l
+        # rescale operands of matmuls
+        p *= l_rcp
+        acc *= (l_prev * l_rcp)[:, None]
         # update acc
-        v = tl.load(v_ptrs + start_n * stride_vk)
         p = p.to(tl.float16)
+        v = tl.load(v_ptrs)
         acc += tl.dot(p, v)
         # update m_i and l_i
-        l_i = l_i_new
-        m_i = m_i_new
+        l_prev = l
+        m_prev = m
+        # update pointers
+        k_ptrs += BLOCK_N * stride_kn
+        v_ptrs += BLOCK_N * stride_vk
     # rematerialize offsets to save registers
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     # write back l and m
     l_ptrs = L + off_hz * N_CTX + offs_m
     m_ptrs = M + off_hz * N_CTX + offs_m
-    tl.store(l_ptrs, l_i)
-    tl.store(m_ptrs, m_i)
+    tl.store(l_ptrs, l_prev)
+    tl.store(m_ptrs, m_prev)
     # initialize pointers to output
     offs_n = tl.arange(0, BLOCK_DMODEL)
     off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
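The kernel hunk above is where the "less math" happens: the per-tile statistics m_ij / l_ij and the two correction factors alpha and beta are gone, replaced by folding the tile maximum straight into the running maximum and rescaling both p and acc once through 1 / l. The NumPy sketch below is an illustration only and not part of the commit; it writes the row-wise rescale explicitly as l_rcp[:, None] and omits the causal mask. It checks that the removed update rule and the new one accumulate the same output:

import numpy as np

rng = np.random.default_rng(0)
BLOCK_M, BLOCK_N, D, N_TILES = 4, 8, 16, 5
scores = rng.standard_normal((BLOCK_M, N_TILES * BLOCK_N))  # stands in for q @ k^T * sm_scale
values = rng.standard_normal((N_TILES * BLOCK_N, D))

def old_update():
    # removed formulation: separate tile stats plus alpha/beta correction factors
    m_i, l_i = np.full(BLOCK_M, -np.inf), np.zeros(BLOCK_M)
    acc = np.zeros((BLOCK_M, D))
    for t in range(N_TILES):
        qk = scores[:, t * BLOCK_N:(t + 1) * BLOCK_N]
        v = values[t * BLOCK_N:(t + 1) * BLOCK_N]
        m_ij = qk.max(1)
        p = np.exp(qk - m_ij[:, None])
        l_ij = p.sum(1)
        m_i_new = np.maximum(m_i, m_ij)
        alpha, beta = np.exp(m_i - m_i_new), np.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        acc = acc * (l_i / l_i_new * alpha)[:, None] + (p * (beta / l_i_new)[:, None]) @ v
        l_i, m_i = l_i_new, m_i_new
    return acc

def new_update():
    # new formulation: one running max, one rescale by 1 / l
    m_prev, l_prev = np.full(BLOCK_M, -np.inf), np.zeros(BLOCK_M)
    acc = np.zeros((BLOCK_M, D))
    for t in range(N_TILES):
        qk = scores[:, t * BLOCK_N:(t + 1) * BLOCK_N]
        v = values[t * BLOCK_N:(t + 1) * BLOCK_N]
        m = np.maximum(qk.max(1), m_prev)     # compute new m
        l_prev = l_prev * np.exp(m_prev - m)  # correct old l
        p = np.exp(qk - m[:, None])           # attention weights
        l = p.sum(1) + l_prev
        l_rcp = 1. / l
        acc = acc * (l_prev * l_rcp)[:, None] + (p * l_rcp[:, None]) @ v
        l_prev, m_prev = l, m
    return acc

probs = np.exp(scores - scores.max(1, keepdims=True))
probs /= probs.sum(1, keepdims=True)
reference = probs @ values
assert np.allclose(old_update(), new_update())
assert np.allclose(new_update(), reference)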
@@ -207,14 +205,13 @@ class _attention(torch.autograd.Function):
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
-        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8

         _fwd_kernel[grid](
             q, k, v, sm_scale,
-            tmp, L, m,
+            L, m,
             o,
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
@@ -222,7 +219,7 @@ class _attention(torch.autograd.Function):
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=4,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
             num_stages=2,
         )

@@ -336,6 +333,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
             do = torch.randn_like(o)
             fn = lambda: o.backward(do, retain_graph=True)
         ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
+        total_flops = 2*flops_per_matmul
+        print(total_flops/ms*1e-9)
         return ms
     if provider == "flash":
         lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
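On the benchmark hunk just above: the printed number is a rough TFLOP/s figure. flops_per_matmul counts one (N_CTX x D_HEAD) @ (D_HEAD x N_CTX) matmul per batch and head at 2 FLOPs per multiply-add, the 0.5 factor accounts for the causal mask touching only about half of the attention matrix, and doubling it covers the two matmuls q @ k^T and p @ v; dividing by the measured milliseconds and scaling by 1e-9 converts to TFLOP/s. A quick arithmetic check, where the sizes and the 10 ms timing are assumptions for illustration rather than values from the commit:

# Sanity check of the TFLOP/s arithmetic (illustrative sizes, not from the commit).
BATCH, H, N_CTX, D_HEAD = 4, 48, 4096, 64
ms = 10.0  # pretend do_bench reported 10 ms

flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5  # causal mask -> ~half the matrix
total_flops = 2 * flops_per_matmul                                # q @ k^T and p @ v

# FLOPs / ms * 1e-9 == FLOPs / (ms * 1e-3 s) * 1e-12, i.e. TFLOP/s
print(total_flops / ms * 1e-9)  # ~41.2 for these sizes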