.
This commit is contained in:
@@ -191,6 +191,7 @@ def _bwd_kernel(
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
# _bwd_kernel = triton.compile("./slow.ttgir", num_warps=8)
|
||||
# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
|
||||
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
|
||||
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
|
||||
# _bwd_kernel[(ctx.grid[1],1,1)](
|
||||
# _bwd_kernel[(ctx.grid[1], 1, 1)](
|
||||
# q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
|
||||
# o.data_ptr(), do_scaled.data_ptr(),
|
||||
# dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
|
||||
|
Reference in New Issue
Block a user