Added TTGIR kernel

commit eefc9d1274
parent 0d6e6cf578
Author: Phil Tillet
Date:   2022-12-27 21:49:28 -08:00

2 changed files with 179 additions and 13 deletions


@@ -46,7 +46,6 @@ def _fwd_kernel(
     q = tl.load(q_ptrs)
     # loop over k, v and update accumulator
     for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(k_ptrs)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
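For context (not part of the diff): the line removed above is a compiler hint. tl.multiple_of tells Triton that a value is always a multiple of a given constant, which lets the optimizer assume pointer alignment and vectorize the loads inside the loop; with the forward pass now compiled from hand-written TTGIR, the hint in the Python source presumably no longer influences the generated code. A minimal sketch of how the hint is used in an ordinary @triton.jit kernel (all names below are made up for illustration):

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance copies one BLOCK-sized tile.
    start = tl.program_id(0) * BLOCK
    # Hint: start is a multiple of BLOCK, so the accesses below are aligned.
    start = tl.multiple_of(start, BLOCK)
    offs = start + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
_copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)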
@@ -192,6 +191,7 @@ def _bwd_kernel(
         tl.store(dv_ptrs, dv)
         tl.store(dk_ptrs, dk)
+_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
 empty = torch.empty(128, device="cuda")
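Not part of the diff, but worth noting: triton.compile here is pointed at a hand-written TTGIR file rather than a @triton.jit function, so the forward kernel is built directly from Triton-GPU IR with num_warps fixed at compile time. One way to bootstrap such a file is to dump the IR of the existing Python kernel and edit it by hand. The sketch below assumes the Triton 2.x development API, where triton.compile accepts signature/constants keywords and the returned object exposes an asm dictionary with a "ttgir" entry; the kernel and file names are illustrative only:

import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
             mask=mask)

# Ahead-of-time compile for a concrete signature (exact keyword names differ
# between Triton versions, so treat this call as a sketch).
compiled = triton.compile(_add_kernel,
                          signature="*fp32,*fp32,*fp32,i32",
                          constants={"BLOCK": 1024},
                          num_warps=4)

# Dump the Triton-GPU IR so it can be hand-edited, then reload it the same
# way the commit loads ./flash-attention.ttgir.
with open("./add.ttgir", "w") as f:
    f.write(compiled.asm["ttgir"])
_add_from_ir = triton.compile("./add.ttgir", num_warps=4)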
@@ -210,19 +210,28 @@ class _attention(torch.autograd.Function):
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
+        # _fwd_kernel[grid](
+        #     q, k, v, sm_scale,
+        #     L, m,
+        #     o,
+        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+        #     o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+        #     q.shape[0], q.shape[1], q.shape[2],
+        #     BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+        #     BLOCK_DMODEL=Lk, num_warps=num_warps,
+        #     num_stages=1,
+        # )
         _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            L, m,
-            o,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
-            num_stages=1,
-        )
+            q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
+            L.data_ptr(), m.data_ptr(),
+            o.data_ptr(),
+            q.stride(0), q.stride(1), q.stride(2),
+            k.stride(0), k.stride(1), k.stride(2),
+            v.stride(0), v.stride(1), v.stride(2),
+            o.stride(0), o.stride(1), o.stride(2),
+            q.shape[0], q.shape[1], q.shape[2])
         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.BLOCK = BLOCK
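A possible sanity check, not part of the commit: the pre-compiled TTGIR kernel is now launched on raw data_ptr() values and only the first three strides of each tensor (presumably because the block sizes and the innermost stride are fixed in the hand-written IR), so it is worth comparing its output against a plain PyTorch attention reference. A sketch, assuming fp16 inputs, causal masking as in the Triton flash-attention tutorial, and the usual attention = _attention.apply alias:

import torch

attention = _attention.apply

torch.manual_seed(0)
Z, H, N_CTX, D_HEAD = 4, 48, 1024, 64
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 0.3

# Output of the TTGIR-backed forward kernel via the autograd Function above.
tri_out = attention(q, k, v, sm_scale)

# Plain PyTorch reference with the same causal mask.
causal_mask = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = (q @ k.transpose(2, 3)) * sm_scale
p = p.float().masked_fill(causal_mask == 0, float("-inf"))
ref_out = torch.softmax(p, dim=-1).half() @ v

torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)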