This commit is contained in:
Philippe Tillet
2023-01-09 22:11:00 -08:00
parent d88353a5a4
commit ff04a5e9b6
4 changed files with 88 additions and 35 deletions

View File

@@ -191,6 +191,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
# _bwd_kernel[(ctx.grid[1],1,1)](
# _bwd_kernel[(ctx.grid[1], 1, 1)](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
# o.data_ptr(), do_scaled.data_ptr(),
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),