cleanup
This commit is contained in:
@@ -191,7 +191,7 @@ def _bwd_kernel(
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
|
||||
# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
@@ -210,28 +210,28 @@ class _attention(torch.autograd.Function):
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
# _fwd_kernel[grid](
|
||||
# q, k, v, sm_scale,
|
||||
# L, m,
|
||||
# o,
|
||||
# q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
# k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
# v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
# o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
# q.shape[0], q.shape[1], q.shape[2],
|
||||
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
# BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
# num_stages=1,
|
||||
# )
|
||||
_fwd_kernel[grid](
|
||||
q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
|
||||
L.data_ptr(), m.data_ptr(),
|
||||
o.data_ptr(),
|
||||
q.stride(0), q.stride(1), q.stride(2),
|
||||
k.stride(0), k.stride(1), k.stride(2),
|
||||
v.stride(0), v.stride(1), v.stride(2),
|
||||
o.stride(0), o.stride(1), o.stride(2),
|
||||
q.shape[0], q.shape[1], q.shape[2])
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
# _fwd_kernel[grid](
|
||||
# q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
|
||||
# L.data_ptr(), m.data_ptr(),
|
||||
# o.data_ptr(),
|
||||
# q.stride(0), q.stride(1), q.stride(2),
|
||||
# k.stride(0), k.stride(1), k.stride(2),
|
||||
# v.stride(0), v.stride(1), v.stride(2),
|
||||
# o.stride(0), o.stride(1), o.stride(2),
|
||||
# q.shape[0], q.shape[1], q.shape[2])
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
|
Reference in New Issue
Block a user