Added verifier for trans

This commit is contained in:
Phil Tillet
2023-01-08 14:29:17 -08:00
parent 42421fabc5
commit 6c750b6856
9 changed files with 243 additions and 200 deletions

View File

@@ -191,7 +191,8 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
@@ -259,36 +260,36 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],1,1)](
q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
o.data_ptr(), do_scaled.data_ptr(),
dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
l.data_ptr(), m.data_ptr(),
delta.data_ptr(),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0]
)
# pgm = _bwd_kernel[(ctx.grid[1],)](
# q, k, v, ctx.sm_scale,
# o, do_scaled,
# dq, dk, dv,
# l, m,
# delta,
# q.stride(0), q.stride(1), q.stride(2), q.stride(3),
# k.stride(0), k.stride(1), k.stride(2), k.stride(3),
# v.stride(0), v.stride(1), v.stride(2), v.stride(3),
# _bwd_kernel[(ctx.grid[1],1,1)](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
# o.data_ptr(), do_scaled.data_ptr(),
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
# l.data_ptr(), m.data_ptr(),
# delta.data_ptr(),
# q.stride(0), q.stride(1), q.stride(2),
# k.stride(0), k.stride(1), k.stride(2),
# v.stride(0), v.stride(1), v.stride(2),
# q.shape[0], q.shape[1], q.shape[2],
# ctx.grid[0],
# BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
# BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
# num_stages=1,
# ctx.grid[0]
# )
# print(pgm.asm["ttgir"])
# exit()
pgm = _bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
print(pgm.asm["ttgir"])
exit()
return dq, dk, dv, None